diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index fd8eaa8d13fc..d8cf414a8cb7 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,7 +1,4 @@ contact_links: - - name: Forum - url: https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63 - about: General usage questions and community discussions - name: Blank issue url: https://github.com/huggingface/diffusers/issues/new - about: Please note that the Forum is in most places the right place for discussions + about: General usage questions and community discussions diff --git a/.github/workflows/build_docker_images.yml b/.github/workflows/build_docker_images.yml new file mode 100644 index 000000000000..ff4bd66fdde5 --- /dev/null +++ b/.github/workflows/build_docker_images.yml @@ -0,0 +1,50 @@ +name: Build Docker images (nightly) + +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" # every day at midnight + +concurrency: + group: docker-image-builds + cancel-in-progress: false + +env: + REGISTRY: diffusers + +jobs: + build-docker-images: + runs-on: ubuntu-latest + + permissions: + contents: read + packages: write + + strategy: + fail-fast: false + matrix: + image-name: + - diffusers-pytorch-cpu + - diffusers-pytorch-cuda + - diffusers-flax-cpu + - diffusers-flax-tpu + - diffusers-onnxruntime-cpu + - diffusers-onnxruntime-cuda + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v2 + with: + username: ${{ env.REGISTRY }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and push + uses: docker/build-push-action@v3 + with: + no-cache: true + context: ./docker/${{ matrix.image-name }} + push: true + tags: ${{ env.REGISTRY }}/${{ matrix.image-name }}:latest diff --git a/.github/workflows/pr_quality.yml b/.github/workflows/pr_quality.yml index 5a850bea47a1..8d6e20efe364 100644 --- a/.github/workflows/pr_quality.yml +++ b/.github/workflows/pr_quality.yml @@ -31,3 +31,20 @@ jobs: isort --check-only examples tests src utils scripts flake8 examples tests src utils scripts doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source + + check_repository_consistency: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.7" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[quality] + - name: Check quality + run: | + python utils/check_copies.py + python utils/check_dummies.py diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 493ff51484b0..55a9bd68deb3 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -10,19 +10,46 @@ concurrency: cancel-in-progress: true env: - OMP_NUM_THREADS: 8 - MKL_NUM_THREADS: 8 + DIFFUSERS_IS_CI: yes + OMP_NUM_THREADS: 4 + MKL_NUM_THREADS: 4 PYTEST_TIMEOUT: 60 MPS_TORCH_VERSION: 1.13.0 jobs: - run_tests_cpu: - name: CPU tests on Ubuntu - runs-on: [ self-hosted, docker-gpu ] + run_fast_tests: + strategy: + fail-fast: false + matrix: + config: + - name: Fast PyTorch CPU tests on Ubuntu + framework: pytorch + runner: docker-cpu + image: diffusers/diffusers-pytorch-cpu + report: torch_cpu + - name: Fast Flax CPU tests on Ubuntu + framework: flax + runner: docker-cpu + image: diffusers/diffusers-flax-cpu + report: flax_cpu + - name: Fast ONNXRuntime CPU tests on Ubuntu + framework: onnxruntime + runner: docker-cpu + image: diffusers/diffusers-onnxruntime-cpu + report: onnx_cpu + + name: ${{ matrix.config.name }} + + runs-on: ${{ matrix.config.runner }} + container: - image: python:3.7 + image: ${{ matrix.config.image }} options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ + defaults: + run: + shell: bash + steps: - name: Checkout diffusers uses: actions/checkout@v3 @@ -31,31 +58,51 @@ jobs: - name: Install dependencies run: | - python -m pip install --upgrade pip - python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu python -m pip install -e .[quality,test] + python -m pip install git+https://github.com/huggingface/accelerate + python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment run: | python utils/print_env.py - - name: Run all fast tests on CPU + - name: Run fast PyTorch CPU tests + if: ${{ matrix.config.framework == 'pytorch' }} + run: | + python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \ + -s -v -k "not Flax and not Onnx" \ + --make-reports=tests_${{ matrix.config.report }} \ + tests/ + + - name: Run fast Flax TPU tests + if: ${{ matrix.config.framework == 'flax' }} + run: | + python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \ + -s -v -k "Flax" \ + --make-reports=tests_${{ matrix.config.report }} \ + tests/ + + - name: Run fast ONNXRuntime CPU tests + if: ${{ matrix.config.framework == 'onnxruntime' }} run: | - python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_cpu tests/ + python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \ + -s -v -k "Onnx" \ + --make-reports=tests_${{ matrix.config.report }} \ + tests/ - name: Failure short reports if: ${{ failure() }} - run: cat reports/tests_torch_cpu_failures_short.txt + run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt - name: Test suite reports artifacts if: ${{ always() }} uses: actions/upload-artifact@v2 with: - name: pr_torch_cpu_test_reports + name: pr_${{ matrix.config.report }}_test_reports path: reports - run_tests_apple_m1: - name: MPS tests on Apple M1 + run_fast_tests_apple_m1: + name: Fast PyTorch MPS tests on MacOS runs-on: [ self-hosted, apple-m1 ] steps: @@ -80,16 +127,18 @@ jobs: ${CONDA_RUN} python -m pip install --upgrade pip ${CONDA_RUN} python -m pip install -e .[quality,test] ${CONDA_RUN} python -m pip install --pre torch==${MPS_TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/test/cpu + ${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate + ${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment shell: arch -arch arm64 bash {0} run: | ${CONDA_RUN} python utils/print_env.py - - name: Run all fast tests on MPS + - name: Run fast PyTorch tests on M1 (MPS) shell: arch -arch arm64 bash {0} run: | - ${CONDA_RUN} python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_mps tests/ + ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/ - name: Failure short reports if: ${{ failure() }} diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 3e4a81c91c01..4bab00b7eeb0 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -6,6 +6,7 @@ on: - main env: + DIFFUSERS_IS_CI: yes HF_HOME: /mnt/cache OMP_NUM_THREADS: 8 MKL_NUM_THREADS: 8 @@ -13,12 +14,38 @@ env: RUN_SLOW: yes jobs: - run_tests_single_gpu: - name: Diffusers tests - runs-on: [ self-hosted, docker-gpu, single-gpu ] + run_slow_tests: + strategy: + fail-fast: false + matrix: + config: + - name: Slow PyTorch CUDA tests on Ubuntu + framework: pytorch + runner: docker-gpu + image: diffusers/diffusers-pytorch-cuda + report: torch_cuda + - name: Slow Flax TPU tests on Ubuntu + framework: flax + runner: docker-tpu + image: diffusers/diffusers-flax-tpu + report: flax_tpu + - name: Slow ONNXRuntime CUDA tests on Ubuntu + framework: onnxruntime + runner: docker-gpu + image: diffusers/diffusers-onnxruntime-cuda + report: onnx_cuda + + name: ${{ matrix.config.name }} + + runs-on: ${{ matrix.config.runner }} + container: - image: nvcr.io/nvidia/pytorch:22.07-py3 - options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache + image: ${{ matrix.config.image }} + options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ ${{ matrix.config.runner == 'docker-tpu' && '--privileged' || '--gpus 0'}} + + defaults: + run: + shell: bash steps: - name: Checkout diffusers @@ -27,45 +54,69 @@ jobs: fetch-depth: 2 - name: NVIDIA-SMI + if : ${{ matrix.config.runner == 'docker-gpu' }} run: | nvidia-smi - name: Install dependencies run: | - python -m pip install --upgrade pip - python -m pip uninstall -y torch torchvision torchtext - python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 python -m pip install -e .[quality,test] + python -m pip install git+https://github.com/huggingface/accelerate + python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment run: | python utils/print_env.py - - name: Run all (incl. slow) tests on GPU + - name: Run slow PyTorch CUDA tests + if: ${{ matrix.config.framework == 'pytorch' }} + env: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + run: | + python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ + -s -v -k "not Flax and not Onnx" \ + --make-reports=tests_${{ matrix.config.report }} \ + tests/ + + - name: Run slow Flax TPU tests + if: ${{ matrix.config.framework == 'flax' }} + env: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + run: | + python -m pytest -n 0 \ + -s -v -k "Flax" \ + --make-reports=tests_${{ matrix.config.report }} \ + tests/ + + - name: Run slow ONNXRuntime CUDA tests + if: ${{ matrix.config.framework == 'onnxruntime' }} env: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} run: | - python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_gpu tests/ + python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ + -s -v -k "Onnx" \ + --make-reports=tests_${{ matrix.config.report }} \ + tests/ - name: Failure short reports if: ${{ failure() }} - run: cat reports/tests_torch_gpu_failures_short.txt + run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt - name: Test suite reports artifacts if: ${{ always() }} uses: actions/upload-artifact@v2 with: - name: torch_test_reports + name: ${{ matrix.config.report }}_test_reports path: reports + run_examples_tests: + name: Examples PyTorch CUDA tests on Ubuntu + runs-on: docker-gpu - run_examples_single_gpu: - name: Examples tests - runs-on: [ self-hosted, docker-gpu, single-gpu ] container: - image: nvcr.io/nvidia/pytorch:22.07-py3 - options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache + image: diffusers/diffusers-pytorch-cuda + options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ steps: - name: Checkout diffusers @@ -79,10 +130,9 @@ jobs: - name: Install dependencies run: | - python -m pip install --upgrade pip - python -m pip uninstall -y torch torchvision torchtext - python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 python -m pip install -e .[quality,test,training] + python -m pip install git+https://github.com/huggingface/accelerate + python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment run: | @@ -92,11 +142,11 @@ jobs: env: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} run: | - python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_gpu examples/ + python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/ - name: Failure short reports if: ${{ failure() }} - run: cat reports/examples_torch_gpu_failures_short.txt + run: cat reports/examples_torch_cuda_failures_short.txt - name: Test suite reports artifacts if: ${{ always() }} diff --git a/.gitignore b/.gitignore index cf8183463613..f018a111ea33 100644 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,6 @@ tags *.lock # DS_Store (MacOS) -.DS_Store \ No newline at end of file +.DS_Store +# RL pipelines may produce mp4 outputs +*.mp4 \ No newline at end of file diff --git a/Makefile b/Makefile index 6e513e2e4743..ea7537f2fd85 100644 --- a/Makefile +++ b/Makefile @@ -67,6 +67,7 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency # Make marked copies of snippets of codes conform to the original fix-copies: + python utils/check_copies.py --fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite # Run tests for the library diff --git a/docker/diffusers-flax-cpu/Dockerfile b/docker/diffusers-flax-cpu/Dockerfile new file mode 100644 index 000000000000..a4b4ccd65b39 --- /dev/null +++ b/docker/diffusers-flax-cpu/Dockerfile @@ -0,0 +1,42 @@ +FROM ubuntu:20.04 +LABEL maintainer="Hugging Face" +LABEL repository="diffusers" + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt update && \ + apt install -y bash \ + build-essential \ + git \ + git-lfs \ + curl \ + ca-certificates \ + python3.8 \ + python3-pip \ + python3.8-venv && \ + rm -rf /var/lib/apt/lists + +# make sure to use venv +RUN python3 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) +# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container +RUN python3 -m pip install --no-cache-dir --upgrade pip && \ + python3 -m pip install --upgrade --no-cache-dir \ + clu \ + "jax[cpu]>=0.2.16,!=0.3.2" \ + "flax>=0.4.1" \ + "jaxlib>=0.1.65" && \ + python3 -m pip install --no-cache-dir \ + accelerate \ + datasets \ + hf-doc-builder \ + huggingface-hub \ + modelcards \ + numpy \ + scipy \ + tensorboard \ + transformers + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/docker/diffusers-flax-tpu/Dockerfile b/docker/diffusers-flax-tpu/Dockerfile new file mode 100644 index 000000000000..5508af6622dd --- /dev/null +++ b/docker/diffusers-flax-tpu/Dockerfile @@ -0,0 +1,44 @@ +FROM ubuntu:20.04 +LABEL maintainer="Hugging Face" +LABEL repository="diffusers" + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt update && \ + apt install -y bash \ + build-essential \ + git \ + git-lfs \ + curl \ + ca-certificates \ + python3.8 \ + python3-pip \ + python3.8-venv && \ + rm -rf /var/lib/apt/lists + +# make sure to use venv +RUN python3 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) +# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container +RUN python3 -m pip install --no-cache-dir --upgrade pip && \ + python3 -m pip install --no-cache-dir \ + "jax[tpu]>=0.2.16,!=0.3.2" \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \ + python3 -m pip install --upgrade --no-cache-dir \ + clu \ + "flax>=0.4.1" \ + "jaxlib>=0.1.65" && \ + python3 -m pip install --no-cache-dir \ + accelerate \ + datasets \ + hf-doc-builder \ + huggingface-hub \ + modelcards \ + numpy \ + scipy \ + tensorboard \ + transformers + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/docker/diffusers-onnxruntime-cpu/Dockerfile b/docker/diffusers-onnxruntime-cpu/Dockerfile new file mode 100644 index 000000000000..c925715915cd --- /dev/null +++ b/docker/diffusers-onnxruntime-cpu/Dockerfile @@ -0,0 +1,42 @@ +FROM ubuntu:20.04 +LABEL maintainer="Hugging Face" +LABEL repository="diffusers" + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt update && \ + apt install -y bash \ + build-essential \ + git \ + git-lfs \ + curl \ + ca-certificates \ + python3.8 \ + python3-pip \ + python3.8-venv && \ + rm -rf /var/lib/apt/lists + +# make sure to use venv +RUN python3 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) +RUN python3 -m pip install --no-cache-dir --upgrade pip && \ + python3 -m pip install --no-cache-dir \ + torch \ + torchvision \ + torchaudio \ + onnxruntime \ + --extra-index-url https://download.pytorch.org/whl/cpu && \ + python3 -m pip install --no-cache-dir \ + accelerate \ + datasets \ + hf-doc-builder \ + huggingface-hub \ + modelcards \ + numpy \ + scipy \ + tensorboard \ + transformers + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/docker/diffusers-onnxruntime-cuda/Dockerfile b/docker/diffusers-onnxruntime-cuda/Dockerfile new file mode 100644 index 000000000000..e51a5e0ba30f --- /dev/null +++ b/docker/diffusers-onnxruntime-cuda/Dockerfile @@ -0,0 +1,42 @@ +FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04 +LABEL maintainer="Hugging Face" +LABEL repository="diffusers" + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt update && \ + apt install -y bash \ + build-essential \ + git \ + git-lfs \ + curl \ + ca-certificates \ + python3.8 \ + python3-pip \ + python3.8-venv && \ + rm -rf /var/lib/apt/lists + +# make sure to use venv +RUN python3 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) +RUN python3 -m pip install --no-cache-dir --upgrade pip && \ + python3 -m pip install --no-cache-dir \ + torch \ + torchvision \ + torchaudio \ + "onnxruntime-gpu>=1.13.1" \ + --extra-index-url https://download.pytorch.org/whl/cu117 && \ + python3 -m pip install --no-cache-dir \ + accelerate \ + datasets \ + hf-doc-builder \ + huggingface-hub \ + modelcards \ + numpy \ + scipy \ + tensorboard \ + transformers + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/docker/diffusers-pytorch-cpu/Dockerfile b/docker/diffusers-pytorch-cpu/Dockerfile new file mode 100644 index 000000000000..41d1672f60e6 --- /dev/null +++ b/docker/diffusers-pytorch-cpu/Dockerfile @@ -0,0 +1,41 @@ +FROM ubuntu:20.04 +LABEL maintainer="Hugging Face" +LABEL repository="diffusers" + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt update && \ + apt install -y bash \ + build-essential \ + git \ + git-lfs \ + curl \ + ca-certificates \ + python3.8 \ + python3-pip \ + python3.8-venv && \ + rm -rf /var/lib/apt/lists + +# make sure to use venv +RUN python3 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) +RUN python3 -m pip install --no-cache-dir --upgrade pip && \ + python3 -m pip install --no-cache-dir \ + torch \ + torchvision \ + torchaudio \ + --extra-index-url https://download.pytorch.org/whl/cpu && \ + python3 -m pip install --no-cache-dir \ + accelerate \ + datasets \ + hf-doc-builder \ + huggingface-hub \ + modelcards \ + numpy \ + scipy \ + tensorboard \ + transformers + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/docker/diffusers-pytorch-cuda/Dockerfile b/docker/diffusers-pytorch-cuda/Dockerfile new file mode 100644 index 000000000000..ba80395c89f4 --- /dev/null +++ b/docker/diffusers-pytorch-cuda/Dockerfile @@ -0,0 +1,41 @@ +FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04 +LABEL maintainer="Hugging Face" +LABEL repository="diffusers" + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt update && \ + apt install -y bash \ + build-essential \ + git \ + git-lfs \ + curl \ + ca-certificates \ + python3.8 \ + python3-pip \ + python3.8-venv && \ + rm -rf /var/lib/apt/lists + +# make sure to use venv +RUN python3 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) +RUN python3 -m pip install --no-cache-dir --upgrade pip && \ + python3 -m pip install --no-cache-dir \ + torch \ + torchvision \ + torchaudio \ + --extra-index-url https://download.pytorch.org/whl/cu117 && \ + python3 -m pip install --no-cache-dir \ + accelerate \ + datasets \ + hf-doc-builder \ + huggingface-hub \ + modelcards \ + numpy \ + scipy \ + tensorboard \ + transformers + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 390b8ff042db..957144488323 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -10,11 +10,13 @@ - sections: - local: using-diffusers/loading title: "Loading Pipelines, Models, and Schedulers" + - local: using-diffusers/schedulers + title: "Using different Schedulers" - local: using-diffusers/configuration title: "Configuring Pipelines, Models, and Schedulers" - - local: using-diffusers/custom_pipelines - title: "Loading and Creating Custom Pipelines" - title: "Loading" + - local: using-diffusers/custom_pipeline_overview + title: "Loading and Adding Custom Pipelines" + title: "Loading & Hub" - sections: - local: using-diffusers/unconditional_image_generation title: "Unconditional Image Generation" @@ -24,9 +26,19 @@ title: "Text-Guided Image-to-Image" - local: using-diffusers/inpaint title: "Text-Guided Image-Inpainting" - - local: using-diffusers/custom - title: "Create a custom pipeline" + - local: using-diffusers/custom_pipeline_examples + title: "Community Pipelines" + - local: using-diffusers/contribute_pipeline + title: "How to contribute a Pipeline" title: "Pipelines for Inference" + - sections: + - local: using-diffusers/rl + title: "Reinforcement Learning" + - local: using-diffusers/audio + title: "Audio" + - local: using-diffusers/other-modalities + title: "Other Modalities" + title: "Taking Diffusers Beyond Images" title: "Using Diffusers" - sections: - local: optimization/fp16 @@ -34,7 +46,7 @@ - local: optimization/onnx title: "ONNX" - local: optimization/open_vino - title: "Open Vino" + title: "OpenVINO" - local: optimization/mps title: "MPS" title: "Optimization/Special Hardware" @@ -44,9 +56,11 @@ - local: training/unconditional_training title: "Unconditional Image Generation" - local: training/text_inversion - title: "Text Inversion" + title: "Textual Inversion" + - local: training/dreambooth + title: "Dreambooth" - local: training/text2image - title: "Text-to-image" + title: "Text-to-image fine-tuning" title: "Training" - sections: - local: conceptual/stable_diffusion @@ -74,6 +88,10 @@ - sections: - local: api/pipelines/overview title: "Overview" + - local: api/pipelines/alt_diffusion + title: "AltDiffusion" + - local: api/pipelines/cycle_diffusion + title: "Cycle Diffusion" - local: api/pipelines/ddim title: "DDIM" - local: api/pipelines/ddpm @@ -88,7 +106,23 @@ title: "Score SDE VE" - local: api/pipelines/stable_diffusion title: "Stable Diffusion" + - local: api/pipelines/stable_diffusion_2 + title: "Stable Diffusion 2" + - local: api/pipelines/stable_diffusion_safe + title: "Safe Stable Diffusion" - local: api/pipelines/stochastic_karras_ve title: "Stochastic Karras VE" + - local: api/pipelines/dance_diffusion + title: "Dance Diffusion" + - local: api/pipelines/versatile_diffusion + title: "Versatile Diffusion" + - local: api/pipelines/vq_diffusion + title: "VQ Diffusion" + - local: api/pipelines/repaint + title: "RePaint" title: "Pipelines" + - sections: + - local: api/experimental/rl + title: "RL Planning" + title: "Experimental Features" title: "API" diff --git a/docs/source/api/configuration.mdx b/docs/source/api/configuration.mdx index 45176f55b018..423c31f462b6 100644 --- a/docs/source/api/configuration.mdx +++ b/docs/source/api/configuration.mdx @@ -15,9 +15,9 @@ specific language governing permissions and limitations under the License. In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`], and models of type [`ModelMixin`] inherit from [`ConfigMixin`] which conveniently takes care of storing all parameters that are passed to the respective `__init__` methods in a JSON-configuration file. -TODO(PVP) - add example and better info here - ## ConfigMixin + [[autodoc]] ConfigMixin + - load_config - from_config - save_config diff --git a/docs/source/api/diffusion_pipeline.mdx b/docs/source/api/diffusion_pipeline.mdx index 6a0f758f7694..b037b4e26dc1 100644 --- a/docs/source/api/diffusion_pipeline.mdx +++ b/docs/source/api/diffusion_pipeline.mdx @@ -32,6 +32,9 @@ Any pipeline object can be saved locally with [`~DiffusionPipeline.save_pretrain [[autodoc]] DiffusionPipeline - from_pretrained - save_pretrained + - to + - device + - components ## ImagePipelineOutput By default diffusion pipelines return an object of class diff --git a/docs/source/using-diffusers/custom.mdx b/docs/source/api/experimental/rl.mdx similarity index 93% rename from docs/source/using-diffusers/custom.mdx rename to docs/source/api/experimental/rl.mdx index 2d17adaff711..65abb06e7523 100644 --- a/docs/source/using-diffusers/custom.mdx +++ b/docs/source/api/experimental/rl.mdx @@ -10,6 +10,6 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Custom Pipeline +# TODO -Under construction 🚧 +Coming soon! \ No newline at end of file diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index c92fdccb8333..7c1faa847440 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -25,6 +25,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## UNet2DModel [[autodoc]] UNet2DModel +## UNet1DOutput +[[autodoc]] models.unet_1d.UNet1DOutput + +## UNet1DModel +[[autodoc]] UNet1DModel + ## UNet2DConditionOutput [[autodoc]] models.unet_2d_condition.UNet2DConditionOutput @@ -46,6 +52,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## AutoencoderKL [[autodoc]] AutoencoderKL +## Transformer2DModel +[[autodoc]] Transformer2DModel + +## Transformer2DModelOutput +[[autodoc]] models.attention.Transformer2DModelOutput + ## FlaxModelMixin [[autodoc]] FlaxModelMixin diff --git a/docs/source/api/pipelines/alt_diffusion.mdx b/docs/source/api/pipelines/alt_diffusion.mdx new file mode 100644 index 000000000000..8d7d795d7633 --- /dev/null +++ b/docs/source/api/pipelines/alt_diffusion.mdx @@ -0,0 +1,83 @@ + + +# AltDiffusion + +AltDiffusion was proposed in [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) by Zhongzhi Chen, Guang Liu, Bo-Wen Zhang, Fulong Ye, Qinghong Yang, Ledell Wu + +The abstract of the paper is the following: + +*In this work, we present a conceptually simple and effective method to train a strong bilingual multimodal representation model. Starting from the pretrained multimodal representation model CLIP released by OpenAI, we switched its text encoder with a pretrained multilingual text encoder XLM-R, and aligned both languages and image representations by a two-stage training schema consisting of teacher learning and contrastive learning. We validate our method through evaluations of a wide range of tasks. We set new state-of-the-art performances on a bunch of tasks including ImageNet-CN, Flicker30k- CN, and COCO-CN. Further, we obtain very close performances with CLIP on almost all tasks, suggesting that one can simply alter the text encoder in CLIP for extended capabilities such as multilingual understanding.* + + +*Overview*: + +| Pipeline | Tasks | Colab | Demo +|---|---|:---:|:---:| +| [pipeline_alt_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py) | *Text-to-Image Generation* | - | - +| [pipeline_alt_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | - |- + +## Tips + +- AltDiffusion is conceptually exaclty the same as [Stable Diffusion](./api/pipelines/stable_diffusion). + +- *Run AltDiffusion* + +AltDiffusion can be tested very easily with the [`AltDiffusionPipeline`], [`AltDiffusionImg2ImgPipeline`] and the `"BAAI/AltDiffusion-m9"` checkpoint exactly in the same way it is shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation) and the [Image-to-Image Generation Guide](./using-diffusers/img2img). + +- *How to load and use different schedulers.* + +The alt diffusion pipeline uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the alt diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc. +To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: + +```python +>>> from diffusers import AltDiffusionPipeline, EulerDiscreteScheduler + +>>> pipeline = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9") +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +>>> # or +>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("BAAI/AltDiffusion-m9", subfolder="scheduler") +>>> pipeline = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", scheduler=euler_scheduler) +``` + + +- *How to convert all use cases with multiple or single pipeline* + +If you want to use all possible use cases in a single `DiffusionPipeline` we recommend using the `components` functionality to instantiate all components in the most memory-efficient way: + +```python +>>> from diffusers import ( +... AltDiffusionPipeline, +... AltDiffusionImg2ImgPipeline, +... ) + +>>> text2img = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9") +>>> img2img = AltDiffusionImg2ImgPipeline(**text2img.components) + +>>> # now you can use text2img(...) and img2img(...) just like the call methods of each respective pipeline +``` + +## AltDiffusionPipelineOutput +[[autodoc]] pipelines.alt_diffusion.AltDiffusionPipelineOutput + +## AltDiffusionPipeline +[[autodoc]] AltDiffusionPipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing + +## AltDiffusionImg2ImgPipeline +[[autodoc]] AltDiffusionImg2ImgPipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing diff --git a/docs/source/api/pipelines/cycle_diffusion.mdx b/docs/source/api/pipelines/cycle_diffusion.mdx new file mode 100644 index 000000000000..8eecd3d62494 --- /dev/null +++ b/docs/source/api/pipelines/cycle_diffusion.mdx @@ -0,0 +1,99 @@ + + +# Cycle Diffusion + +## Overview + +Cycle Diffusion is a Text-Guided Image-to-Image Generation model proposed in [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) by Chen Henry Wu, Fernando De la Torre. + +The abstract of the paper is the following: + +*Diffusion models have achieved unprecedented performance in generative modeling. The commonly-adopted formulation of the latent code of diffusion models is a sequence of gradually denoised samples, as opposed to the simpler (e.g., Gaussian) latent space of GANs, VAEs, and normalizing flows. This paper provides an alternative, Gaussian formulation of the latent space of various diffusion models, as well as an invertible DPM-Encoder that maps images into the latent space. While our formulation is purely based on the definition of diffusion models, we demonstrate several intriguing consequences. (1) Empirically, we observe that a common latent space emerges from two diffusion models trained independently on related domains. In light of this finding, we propose CycleDiffusion, which uses DPM-Encoder for unpaired image-to-image translation. Furthermore, applying CycleDiffusion to text-to-image diffusion models, we show that large-scale text-to-image diffusion models can be used as zero-shot image-to-image editors. (2) One can guide pre-trained diffusion models and GANs by controlling the latent codes in a unified, plug-and-play formulation based on energy-based models. Using the CLIP model and a face recognition model as guidance, we demonstrate that diffusion models have better coverage of low-density sub-populations and individuals than GANs.* + +*Tips*: +- The Cycle Diffusion pipeline is fully compatible with any [Stable Diffusion](./stable_diffusion) checkpoints +- Currently Cycle Diffusion only works with the [`DDIMScheduler`]. + +*Example*: + +In the following we should how to best use the [`CycleDiffusionPipeline`] + +```python +import requests +import torch +from PIL import Image +from io import BytesIO + +from diffusers import CycleDiffusionPipeline, DDIMScheduler + +# load the pipeline +# make sure you're logged in with `huggingface-cli login` +model_id_or_path = "CompVis/stable-diffusion-v1-4" +scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler") +pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda") + +# let's download an initial image +url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png" +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((512, 512)) +init_image.save("horse.png") + +# let's specify a prompt +source_prompt = "An astronaut riding a horse" +prompt = "An astronaut riding an elephant" + +# call the pipeline +image = pipe( + prompt=prompt, + source_prompt=source_prompt, + init_image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.8, + guidance_scale=2, + source_guidance_scale=1, +).images[0] + +image.save("horse_to_elephant.png") + +# let's try another example +# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion +url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png" +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((512, 512)) +init_image.save("black.png") + +source_prompt = "A black colored car" +prompt = "A blue colored car" + +# call the pipeline +torch.manual_seed(0) +image = pipe( + prompt=prompt, + source_prompt=source_prompt, + init_image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.85, + guidance_scale=3, + source_guidance_scale=1, +).images[0] + +image.save("black_to_blue.png") +``` + +## CycleDiffusionPipeline +[[autodoc]] CycleDiffusionPipeline + - __call__ diff --git a/docs/source/api/pipelines/dance_diffusion.mdx b/docs/source/api/pipelines/dance_diffusion.mdx new file mode 100644 index 000000000000..4d969bf6f032 --- /dev/null +++ b/docs/source/api/pipelines/dance_diffusion.mdx @@ -0,0 +1,33 @@ + + +# Dance Diffusion + +## Overview + +[Dance Diffusion](https://github.com/Harmonai-org/sample-generator) by Zach Evans. + +Dance Diffusion is the first in a suite of generative audio tools for producers and musicians to be released by Harmonai. +For more info or to get involved in the development of these tools, please visit https://harmonai.org and fill out the form on the front page. + +The original codebase of this implementation can be found [here](https://github.com/Harmonai-org/sample-generator). + +## Available Pipelines: + +| Pipeline | Tasks | Colab +|---|---|:---:| +| [pipeline_dance_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py) | *Unconditional Audio Generation* | - | + + +## DanceDiffusionPipeline +[[autodoc]] DanceDiffusionPipeline + - __call__ diff --git a/docs/source/api/pipelines/ddim.mdx b/docs/source/api/pipelines/ddim.mdx index 41952c2c0dfe..a7a5421b36fe 100644 --- a/docs/source/api/pipelines/ddim.mdx +++ b/docs/source/api/pipelines/ddim.mdx @@ -1,3 +1,15 @@ + + # DDIM ## Overview @@ -8,7 +20,8 @@ The abstract of the paper is the following: Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space. -The original codebase of this paper can be found [here](https://github.com/ermongroup/ddim). +The original codebase of this paper can be found here: [ermongroup/ddim](https://github.com/ermongroup/ddim). +For questions, feel free to contact the author on [tsong.me](https://tsong.me/). ## Available Pipelines: diff --git a/docs/source/api/pipelines/ddpm.mdx b/docs/source/api/pipelines/ddpm.mdx index b0e08f84ef74..c6d8a6f28660 100644 --- a/docs/source/api/pipelines/ddpm.mdx +++ b/docs/source/api/pipelines/ddpm.mdx @@ -1,3 +1,15 @@ + + # DDPM ## Overview diff --git a/docs/source/api/pipelines/latent_diffusion.mdx b/docs/source/api/pipelines/latent_diffusion.mdx index 821ad4fea5b2..370d014f5a10 100644 --- a/docs/source/api/pipelines/latent_diffusion.mdx +++ b/docs/source/api/pipelines/latent_diffusion.mdx @@ -1,3 +1,15 @@ + + # Latent Diffusion ## Overview @@ -21,10 +33,15 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff | Pipeline | Tasks | Colab |---|---|:---:| | [pipeline_latent_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) | *Text-to-Image Generation* | - | +| [pipeline_latent_diffusion_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py) | *Super Resolution* | - | ## Examples: ## LDMTextToImagePipeline -[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion.LDMTextToImagePipeline +[[autodoc]] LDMTextToImagePipeline + - __call__ + +## LDMSuperResolutionPipeline +[[autodoc]] LDMSuperResolutionPipeline - __call__ diff --git a/docs/source/api/pipelines/latent_diffusion_uncond.mdx b/docs/source/api/pipelines/latent_diffusion_uncond.mdx index 4e12fa7f525c..0a5b20cd4a1c 100644 --- a/docs/source/api/pipelines/latent_diffusion_uncond.mdx +++ b/docs/source/api/pipelines/latent_diffusion_uncond.mdx @@ -1,3 +1,15 @@ + + # Unconditional Latent Diffusion ## Overview diff --git a/docs/source/api/pipelines/overview.mdx b/docs/source/api/pipelines/overview.mdx index 7b2d89e849d6..eed8e0d0b020 100644 --- a/docs/source/api/pipelines/overview.mdx +++ b/docs/source/api/pipelines/overview.mdx @@ -28,7 +28,7 @@ or created independently from each other. To that end, we strive to offer all open-sourced, state-of-the-art diffusion system under a unified API. More specifically, we strive to provide pipelines that -- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LatentDiffusionPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)), +- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)), - 2. have a simple user interface to run the model in inference (see the [Pipelines API](#pipelines-api) section), - 3. are easy to understand with code that is self-explanatory and can be read along-side the official paper (see [Pipelines summary](#pipelines-summary)), - 4. can easily be contributed by the community (see the [Contribution](#contribution) section). @@ -41,19 +41,33 @@ If you are looking for *official* training examples, please have a look at [exam The following table summarizes all officially supported pipelines, their corresponding paper, and if available a colab notebook to directly try them out. + | Pipeline | Paper | Tasks | Colab |---|---|:---:|:---:| -| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | -| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) -| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | -| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | -| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | -| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | -| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | -| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) -| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) -| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) -| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | - +| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation | +| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | +| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | +| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | +| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | +| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image | +| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | +| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | +| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | +| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | +| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) +| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation | +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting | +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image | +| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) +| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation | +| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | + **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. @@ -67,8 +81,8 @@ Diffusion models often consist of multiple independently-trained models or other Each model has been trained independently on a different task and the scheduler can easily be swapped out and replaced with a different one. During inference, we however want to be able to easily load all components and use them in inference - even if one component, *e.g.* CLIP's text encoder, originates from a different library, such as [Transformers](https://github.com/huggingface/transformers). To that end, all pipelines provide the following functionality: -- [`from_pretrained` method](../diffusion_pipeline) that accepts a Hugging Face Hub repository id, *e.g.* [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) or a path to a local directory, *e.g.* -"./stable-diffusion". To correctly retrieve which models and components should be loaded, one has to provide a `model_index.json` file, *e.g.* [CompVis/stable-diffusion-v1-4/model_index.json](https://huggingface.co/CompVis/stable-diffusion-v1-4/blob/main/model_index.json), which defines all components that should be +- [`from_pretrained` method](../diffusion_pipeline) that accepts a Hugging Face Hub repository id, *e.g.* [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) or a path to a local directory, *e.g.* +"./stable-diffusion". To correctly retrieve which models and components should be loaded, one has to provide a `model_index.json` file, *e.g.* [runwayml/stable-diffusion-v1-5/model_index.json](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json), which defines all components that should be loaded into the pipelines. More specifically, for each model/component one needs to define the format `: ["", ""]`. `` is the attribute name given to the loaded instance of `` which can be found in the library or pipeline folder called `""`. - [`save_pretrained`](../diffusion_pipeline) that accepts a local path, *e.g.* `./stable-diffusion` under which all models/components of the pipeline will be saved. For each component/model a folder is created inside the local path that is named after the given attribute name, *e.g.* `./stable_diffusion/unet`. In addition, a `model_index.json` file is created at the root of the local path, *e.g.* `./stable_diffusion/model_index.json` so that the complete pipeline can again be instantiated @@ -100,7 +114,7 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing # make sure you're logged in with `huggingface-cli login` from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler -pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") +pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" @@ -123,7 +137,7 @@ from diffusers import StableDiffusionImg2ImgPipeline # load the pipeline device = "cuda" pipe = StableDiffusionImg2ImgPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16 + "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16 ).to(device) # let's download an initial image @@ -151,10 +165,10 @@ You can generate your own latents to reproduce results, or tweak your prompt on The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by providing a mask and text prompt. ```python -from io import BytesIO - -import requests import PIL +import requests +import torch +from io import BytesIO from diffusers import StableDiffusionInpaintPipeline @@ -170,15 +184,15 @@ mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data init_image = download_image(img_url).resize((512, 512)) mask_image = download_image(mask_url).resize((512, 512)) -device = "cuda" pipe = StableDiffusionInpaintPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16 -).to(device) - -prompt = "a cat sitting on a bench" -images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images + "runwayml/stable-diffusion-inpainting", + revision="fp16", + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") -images[0].save("cat_on_bench.png") +prompt = "Face of a yellow cat, high resolution, sitting on a park bench" +image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] ``` You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) diff --git a/docs/source/api/pipelines/pndm.mdx b/docs/source/api/pipelines/pndm.mdx index cbb7c5a929c7..89930f4d4f8f 100644 --- a/docs/source/api/pipelines/pndm.mdx +++ b/docs/source/api/pipelines/pndm.mdx @@ -1,3 +1,15 @@ + + # PNDM ## Overview diff --git a/docs/source/api/pipelines/repaint.mdx b/docs/source/api/pipelines/repaint.mdx new file mode 100644 index 000000000000..ce262daffaeb --- /dev/null +++ b/docs/source/api/pipelines/repaint.mdx @@ -0,0 +1,77 @@ + + +# RePaint + +## Overview + +[RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865) (PNDM) by Andreas Lugmayr, Martin Danelljan, Andres Romero, Fisher Yu, Radu Timofte, Luc Van Gool. + +The abstract of the paper is the following: + +Free-form inpainting is the task of adding new content to an image in the regions specified by an arbitrary binary mask. Most existing approaches train for a certain distribution of masks, which limits their generalization capabilities to unseen mask types. Furthermore, training with pixel-wise and perceptual losses often leads to simple textural extensions towards the missing areas instead of semantically meaningful generation. In this work, we propose RePaint: A Denoising Diffusion Probabilistic Model (DDPM) based inpainting approach that is applicable to even extreme masks. We employ a pretrained unconditional DDPM as the generative prior. To condition the generation process, we only alter the reverse diffusion iterations by sampling the unmasked regions using the given image information. Since this technique does not modify or condition the original DDPM network itself, the model produces high-quality and diverse output images for any inpainting form. We validate our method for both faces and general-purpose image inpainting using standard and extreme masks. +RePaint outperforms state-of-the-art Autoregressive, and GAN approaches for at least five out of six mask distributions. + +The original codebase can be found [here](https://github.com/andreas128/RePaint). + +## Available Pipelines: + +| Pipeline | Tasks | Colab +|-------------------------------------------------------------------------------------------------------------------------------|--------------------|:---:| +| [pipeline_repaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/repaint/pipeline_repaint.py) | *Image Inpainting* | - | + +## Usage example + +```python +from io import BytesIO + +import torch + +import PIL +import requests +from diffusers import RePaintPipeline, RePaintScheduler + + +def download_image(url): + response = requests.get(url) + return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + +img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png" +mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png" + +# Load the original image and the mask as PIL images +original_image = download_image(img_url).resize((256, 256)) +mask_image = download_image(mask_url).resize((256, 256)) + +# Load the RePaint scheduler and pipeline based on a pretrained DDPM model +scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256") +pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler) +pipe = pipe.to("cuda") + +generator = torch.Generator(device="cuda").manual_seed(0) +output = pipe( + original_image=original_image, + mask_image=mask_image, + num_inference_steps=250, + eta=0.0, + jump_length=10, + jump_n_sample=10, + generator=generator, +) +inpainted_image = output.images[0] +``` + +## RePaintPipeline +[[autodoc]] pipelines.repaint.pipeline_repaint.RePaintPipeline + - __call__ + diff --git a/docs/source/api/pipelines/score_sde_ve.mdx b/docs/source/api/pipelines/score_sde_ve.mdx index 2e555914ac9b..3d6619c2591f 100644 --- a/docs/source/api/pipelines/score_sde_ve.mdx +++ b/docs/source/api/pipelines/score_sde_ve.mdx @@ -1,3 +1,15 @@ + + # Score SDE VE ## Overview diff --git a/docs/source/api/pipelines/stable_diffusion.mdx b/docs/source/api/pipelines/stable_diffusion.mdx index 3288c679de95..afa72775f06a 100644 --- a/docs/source/api/pipelines/stable_diffusion.mdx +++ b/docs/source/api/pipelines/stable_diffusion.mdx @@ -1,3 +1,15 @@ + + # Stable diffusion pipelines Stable Diffusion is a text-to-image _latent diffusion_ model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/). It's trained on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) dataset. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and can run on consumer GPUs. @@ -17,6 +29,45 @@ For more details about how Stable Diffusion works and how it differs from the ba | [pipeline_stable_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [🤗 Diffuse the Rest](https://huggingface.co/spaces/huggingface/diffuse-the-rest) | [pipeline_stable_diffusion_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | **Experimental** – *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | Coming soon +## Tips + +### How to load and use different schedulers. + +The stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc. +To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: + +```python +>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler + +>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +>>> # or +>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") +>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler) +``` + + +### How to convert all use cases with multiple or single pipeline + +If you want to use all possible use cases in a single `DiffusionPipeline` you can either: +- Make use of the [Stable Diffusion Mega Pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#stable-diffusion-mega) or +- Make use of the `components` functionality to instantiate all components in the most memory-efficient way: + +```python +>>> from diffusers import ( +... StableDiffusionPipeline, +... StableDiffusionImg2ImgPipeline, +... StableDiffusionInpaintPipeline, +... ) + +>>> text2img = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") +>>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components) +>>> inpaint = StableDiffusionInpaintPipeline(**text2img.components) + +>>> # now you can use text2img(...), img2img(...), inpaint(...) just like the call methods of each respective pipeline +``` + ## StableDiffusionPipelineOutput [[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput @@ -37,3 +88,17 @@ For more details about how Stable Diffusion works and how it differs from the ba - __call__ - enable_attention_slicing - disable_attention_slicing + + +## StableDiffusionImageVariationPipeline +[[autodoc]] StableDiffusionImageVariationPipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing + + +## StableDiffusionUpscalePipeline +[[autodoc]] StableDiffusionUpscalePipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing diff --git a/docs/source/api/pipelines/stable_diffusion_2.mdx b/docs/source/api/pipelines/stable_diffusion_2.mdx new file mode 100644 index 000000000000..5df9195034c3 --- /dev/null +++ b/docs/source/api/pipelines/stable_diffusion_2.mdx @@ -0,0 +1,142 @@ + + +# Stable diffusion 2 + +Stable Diffusion 2 is a text-to-image _latent diffusion_ model built upon the work of [Stable Diffusion 1](https://stability.ai/blog/stable-diffusion-public-release). +The project to train Stable Diffusion 2 was led by Robin Rombach and Katherine Crowson from [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/). + +*The Stable Diffusion 2.0 release includes robust text-to-image models trained using a brand new text encoder (OpenCLIP), developed by LAION with support from Stability AI, which greatly improves the quality of the generated images compared to earlier V1 releases. The text-to-image models in this release can generate images with default resolutions of both 512x512 pixels and 768x768 pixels. +These models are trained on an aesthetic subset of the [LAION-5B dataset](https://laion.ai/blog/laion-5b/) created by the DeepFloyd team at Stability AI, which is then further filtered to remove adult content using [LAION’s NSFW filter](https://openreview.net/forum?id=M3Y74vmsMcY).* + +For more details about how Stable Diffusion 2 works and how it differs from Stable Diffusion 1, please refer to the official [launch announcement post](https://stability.ai/blog/stable-diffusion-v2-release). + +## Tips + +### Available checkpoints: + +Note that the architecture is more or less identical to [Stable Diffusion 1](./api/pipelines/stable_diffusion) so please refer to [this page](./api/pipelines/stable_diffusion) for API documentation. + +- *Text-to-Image (512x512 resolution)*: [stabilityai/stable-diffusion-2-base](https://huggingface.co/stabilityai/stable-diffusion-2-base) with [`StableDiffusionPipeline`] +- *Text-to-Image (768x768 resolution)*: [stabilityai/stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) with [`StableDiffusionPipeline`] +- *Image Inpainting (512x512 resolution)*: [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) with [`StableDiffusionInpaintPipeline`] +- *Image Upscaling (x4 resolution resolution)*: [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) [`StableDiffusionUpscalePipeline`] + +We recommend using the [`DPMSolverMultistepScheduler`] as it's currently the fastest scheduler there is. + +- *Text-to-Image (512x512 resolution)*: + +```python +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler +import torch + +repo_id = "stabilityai/stable-diffusion-2-base" +pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16") + +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") + +prompt = "High quality photo of an astronaut riding a horse in space" +image = pipe(prompt, num_inference_steps=25).images[0] +image.save("astronaut.png") +``` + +- *Text-to-Image (768x768 resolution)*: + +```python +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler +import torch + +repo_id = "stabilityai/stable-diffusion-2" +pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16") + +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") + +prompt = "High quality photo of an astronaut riding a horse in space" +image = pipe(prompt, guidance_scale=9, num_inference_steps=25).images[0] +image.save("astronaut.png") +``` + +- *Image Inpainting (512x512 resolution)*: + +```python +import PIL +import requests +import torch +from io import BytesIO + +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler + + +def download_image(url): + response = requests.get(url) + return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + +img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + +init_image = download_image(img_url).resize((512, 512)) +mask_image = download_image(mask_url).resize((512, 512)) + +repo_id = "stabilityai/stable-diffusion-2-inpainting" +pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16") + +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") + +prompt = "Face of a yellow cat, high resolution, sitting on a park bench" +image = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=25).images[0] + +image.save("yellow_cat.png") +``` + +- *Image Upscaling (x4 resolution resolution)*: [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) [`StableDiffusionUpscalePipeline`] + +```python +import requests +from PIL import Image +from io import BytesIO +from diffusers import StableDiffusionUpscalePipeline +import torch + +# load model and scheduler +model_id = "stabilityai/stable-diffusion-x4-upscaler" +pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16) +pipeline = pipeline.to("cuda") + +# let's download an image +url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png" +response = requests.get(url) +low_res_img = Image.open(BytesIO(response.content)).convert("RGB") +low_res_img = low_res_img.resize((128, 128)) +prompt = "a white cat" +upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0] +upscaled_image.save("upsampled_cat.png") +``` + +### How to load and use different schedulers. + +The stable diffusion pipeline uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc. +To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: + +```python +>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler + +>>> pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2") +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +>>> # or +>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler") +>>> pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=euler_scheduler) +``` diff --git a/docs/source/api/pipelines/stable_diffusion_safe.mdx b/docs/source/api/pipelines/stable_diffusion_safe.mdx new file mode 100644 index 000000000000..81fc59d3928c --- /dev/null +++ b/docs/source/api/pipelines/stable_diffusion_safe.mdx @@ -0,0 +1,90 @@ + + +# Safe Stable Diffusion + +Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://arxiv.org/abs/2211.05105) and mitigates the well known issue that models like Stable Diffusion that are trained on unfiltered, web-crawled datasets tend to suffer from inappropriate degeneration. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, or otherwise offensive content. +Safe Stable Diffusion is an extension to the Stable Diffusion that drastically reduces content like this. + +The abstract of the paper is the following: + +*Text-conditioned image generation models have recently achieved astonishing results in image quality and text alignment and are consequently employed in a fast-growing number of applications. Since they are highly data-driven, relying on billion-sized datasets randomly scraped from the internet, they also suffer, as we demonstrate, from degenerated and biased human behavior. In turn, they may even reinforce such biases. To help combat these undesired side effects, we present safe latent diffusion (SLD). Specifically, to measure the inappropriate degeneration due to unfiltered and imbalanced training sets, we establish a novel image generation test bed-inappropriate image prompts (I2P)-containing dedicated, real-world image-to-text prompts covering concepts such as nudity and violence. As our exhaustive empirical evaluation demonstrates, the introduced SLD removes and suppresses inappropriate image parts during the diffusion process, with no additional training required and no adverse effect on overall image quality or text alignment.* + + +*Overview*: + +| Pipeline | Tasks | Colab | Demo +|---|---|:---:|:---:| +| [pipeline_stable_diffusion_safe.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | - + +## Tips + +- Safe Stable Diffusion may also be used with weights of [Stable Diffusion](./api/pipelines/stable_diffusion). + +### Run Safe Stable Diffusion + +Safe Stable Diffusion can be tested very easily with the [`StableDiffusionPipelineSafe`], and the `"AIML-TUDA/stable-diffusion-safe"` checkpoint exactly in the same way it is shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation). + +### Interacting with the Safety Concept + +To check and edit the currently used safety concept, use the `safety_concept` property of [`StableDiffusionPipelineSafe`] +```python +>>> from diffusers import StableDiffusionPipelineSafe + +>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe") +>>> pipeline.safety_concept +``` +For each image generation the active concept is also contained in [`StableDiffusionSafePipelineOutput`]. + +### Using pre-defined safety configurations + +You may use the 4 configurations defined in the [Safe Latent Diffusion paper](https://arxiv.org/abs/2211.05105) as follows: + +```python +>>> from diffusers import StableDiffusionPipelineSafe +>>> from diffusers.pipelines.stable_diffusion_safe import SafetyConfig + +>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe") +>>> prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker" +>>> out = pipeline(prompt=prompt, **SafetyConfig.MAX) +``` + +The following configurations are available: `SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONg`, and `SafetyConfig.MAX`. + +### How to load and use different schedulers. + +The safe stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc. +To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: + +```python +>>> from diffusers import StableDiffusionPipelineSafe, EulerDiscreteScheduler + +>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe") +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +>>> # or +>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("AIML-TUDA/stable-diffusion-safe", subfolder="scheduler") +>>> pipeline = StableDiffusionPipelineSafe.from_pretrained( +... "AIML-TUDA/stable-diffusion-safe", scheduler=euler_scheduler +... ) +``` + + +## StableDiffusionSafePipelineOutput +[[autodoc]] pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput + +## StableDiffusionPipelineSafe +[[autodoc]] StableDiffusionPipelineSafe + - __call__ + - enable_attention_slicing + - disable_attention_slicing + diff --git a/docs/source/api/pipelines/stochastic_karras_ve.mdx b/docs/source/api/pipelines/stochastic_karras_ve.mdx index d1ef8f7b3f6c..de762cbda002 100644 --- a/docs/source/api/pipelines/stochastic_karras_ve.mdx +++ b/docs/source/api/pipelines/stochastic_karras_ve.mdx @@ -1,3 +1,15 @@ + + # Stochastic Karras VE ## Overview diff --git a/docs/source/api/pipelines/versatile_diffusion.mdx b/docs/source/api/pipelines/versatile_diffusion.mdx new file mode 100644 index 000000000000..f557c5b0aac8 --- /dev/null +++ b/docs/source/api/pipelines/versatile_diffusion.mdx @@ -0,0 +1,73 @@ + + +# VersatileDiffusion + +VersatileDiffusion was proposed in [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) by Xingqian Xu, Zhangyang Wang, Eric Zhang, Kai Wang, Humphrey Shi . + +The abstract of the paper is the following: + +*The recent advances in diffusion models have set an impressive milestone in many generation tasks. Trending works such as DALL-E2, Imagen, and Stable Diffusion have attracted great interest in academia and industry. Despite the rapid landscape changes, recent new approaches focus on extensions and performance rather than capacity, thus requiring separate models for separate tasks. In this work, we expand the existing single-flow diffusion pipeline into a multi-flow network, dubbed Versatile Diffusion (VD), that handles text-to-image, image-to-text, image-variation, and text-variation in one unified model. Moreover, we generalize VD to a unified multi-flow multimodal diffusion framework with grouped layers, swappable streams, and other propositions that can process modalities beyond images and text. Through our experiments, we demonstrate that VD and its underlying framework have the following merits: a) VD handles all subtasks with competitive quality; b) VD initiates novel extensions and applications such as disentanglement of style and semantic, image-text dual-guided generation, etc.; c) Through these experiments and applications, VD provides more semantic insights of the generated outputs.* + +## Tips + +- VersatileDiffusion is conceptually very similar as [Stable Diffusion](./api/pipelines/stable_diffusion), but instead of providing just a image data stream conditioned on text, VersatileDiffusion provides both a image and text data stream and can be conditioned on both text and image. + +### *Run VersatileDiffusion* + +You can both load the memory intensive "all-in-one" [`VersatileDiffusionPipeline`] that can run all tasks +with the same class as shown in [`VersatileDiffusionPipeline.text_to_image`], [`VersatileDiffusionPipeline.image_variation`], and [`VersatileDiffusionPipeline.dual_guided`] + +**or** + +You can run the individual pipelines which are much more memory efficient: + +- *Text-to-Image*: [`VersatileDiffusionTextToImagePipeline.__call__`] +- *Image Variation*: [`VersatileDiffusionImageVariationPipeline.__call__`] +- *Dual Text and Image Guided Generation*: [`VersatileDiffusionDualGuidedPipeline.__call__`] + +### *How to load and use different schedulers.* + +The versatile diffusion pipelines uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the alt diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc. +To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: + +```python +>>> from diffusers import VersatileDiffusionPipeline, EulerDiscreteScheduler + +>>> pipeline = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion") +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +>>> # or +>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("shi-labs/versatile-diffusion", subfolder="scheduler") +>>> pipeline = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", scheduler=euler_scheduler) +``` + +## VersatileDiffusionPipeline +[[autodoc]] VersatileDiffusionPipeline + +## VersatileDiffusionTextToImagePipeline +[[autodoc]] VersatileDiffusionTextToImagePipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing + +## VersatileDiffusionImageVariationPipeline +[[autodoc]] VersatileDiffusionImageVariationPipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing + +## VersatileDiffusionDualGuidedPipeline +[[autodoc]] VersatileDiffusionDualGuidedPipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing diff --git a/docs/source/api/pipelines/vq_diffusion.mdx b/docs/source/api/pipelines/vq_diffusion.mdx new file mode 100644 index 000000000000..92cc903eee79 --- /dev/null +++ b/docs/source/api/pipelines/vq_diffusion.mdx @@ -0,0 +1,34 @@ + + +# VQDiffusion + +## Overview + +[Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) by Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, Baining Guo + +The abstract of the paper is the following: + +We present the vector quantized diffusion (VQ-Diffusion) model for text-to-image generation. This method is based on a vector quantized variational autoencoder (VQ-VAE) whose latent space is modeled by a conditional variant of the recently developed Denoising Diffusion Probabilistic Model (DDPM). We find that this latent-space method is well-suited for text-to-image generation tasks because it not only eliminates the unidirectional bias with existing methods but also allows us to incorporate a mask-and-replace diffusion strategy to avoid the accumulation of errors, which is a serious problem with existing methods. Our experiments show that the VQ-Diffusion produces significantly better text-to-image generation results when compared with conventional autoregressive (AR) models with similar numbers of parameters. Compared with previous GAN-based text-to-image methods, our VQ-Diffusion can handle more complex scenes and improve the synthesized image quality by a large margin. Finally, we show that the image generation computation in our method can be made highly efficient by reparameterization. With traditional AR methods, the text-to-image generation time increases linearly with the output image resolution and hence is quite time consuming even for normal size images. The VQ-Diffusion allows us to achieve a better trade-off between quality and speed. Our experiments indicate that the VQ-Diffusion model with the reparameterization is fifteen times faster than traditional AR methods while achieving a better image quality. + +The original codebase can be found [here](https://github.com/microsoft/VQ-Diffusion). + +## Available Pipelines: + +| Pipeline | Tasks | Colab +|---|---|:---:| +| [pipeline_vq_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py) | *Text-to-Image Generation* | - | + + +## VQDiffusionPipeline +[[autodoc]] pipelines.vq_diffusion.pipeline_vq_diffusion.VQDiffusionPipeline + - __call__ diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx index 12a6b5c587bc..7ed527bedf3f 100644 --- a/docs/source/api/schedulers.mdx +++ b/docs/source/api/schedulers.mdx @@ -16,7 +16,7 @@ Diffusers contains multiple pre-built schedule functions for the diffusion proce ## What is a scheduler? -The schedule functions, denoted *Schedulers* in the library take in the output of a trained model, a sample which the diffusion process is iterating on, and a timestep to return a denoised sample. +The schedule functions, denoted *Schedulers* in the library take in the output of a trained model, a sample which the diffusion process is iterating on, and a timestep to return a denoised sample. That's why schedulers may also be called *Samplers* in other diffusion models implementations. - Schedulers define the methodology for iteratively adding noise to an image or for updating a sample based on model outputs. - adding noise in different manners represent the algorithmic processes to train a diffusion model by adding noise to images. @@ -70,6 +70,12 @@ Original paper can be found [here](https://arxiv.org/abs/2010.02502). [[autodoc]] DDPMScheduler +#### Multistep DPM-Solver + +Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver). + +[[autodoc]] DPMSolverMultistepScheduler + #### Variance exploding, stochastic sampling from Karras et. al Original paper can be found [here](https://arxiv.org/abs/2006.11239). @@ -89,13 +95,19 @@ Original implementation can be found [here](https://github.com/crowsonkb/k-diffu [[autodoc]] PNDMScheduler -#### variance exploding stochastic differential equation (SDE) scheduler +#### variance exploding stochastic differential equation (VE-SDE) scheduler Original paper can be found [here](https://arxiv.org/abs/2011.13456). [[autodoc]] ScoreSdeVeScheduler -#### variance preserving stochastic differential equation (SDE) scheduler +#### improved pseudo numerical methods for diffusion models (iPNDM) + +Original implementation can be found [here](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296). + +[[autodoc]] IPNDMScheduler + +#### variance preserving stochastic differential equation (VP-SDE) scheduler Original paper can be found [here](https://arxiv.org/abs/2011.13456). @@ -106,3 +118,34 @@ Score SDE-VP is under construction. [[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler + +#### Euler scheduler + +Euler scheduler (Algorithm 2) from the paper [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) by Karras et al. (2022). Based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by Katherine Crowson. +Fast scheduler which often times generates good outputs with 20-30 steps. + +[[autodoc]] EulerDiscreteScheduler + + +#### Euler Ancestral scheduler + +Ancestral sampling with Euler method steps. Based on the original (k-diffusion)[https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72] implementation by Katherine Crowson. +Fast scheduler which often times generates good outputs with 20-30 steps. + +[[autodoc]] EulerAncestralDiscreteScheduler + + +#### VQDiffusionScheduler + +Original paper can be found [here](https://arxiv.org/abs/2111.14822) + +[[autodoc]] VQDiffusionScheduler + +#### RePaint scheduler + +DDPM-based inpainting scheduler for unsupervised inpainting with extreme masks. +Intended for use with [`RePaintPipeline`]. +Based on the paper [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865) +and the original implementation by Andreas Lugmayr et al.: https://github.com/andreas128/RePaint + +[[autodoc]] RePaintScheduler diff --git a/docs/source/conceptual/stable_diffusion.mdx b/docs/source/conceptual/stable_diffusion.mdx index c00359d2ac1a..3e87674aa83a 100644 --- a/docs/source/conceptual/stable_diffusion.mdx +++ b/docs/source/conceptual/stable_diffusion.mdx @@ -12,6 +12,4 @@ specific language governing permissions and limitations under the License. # Stable Diffusion -Under construction 🚧 - -For now please visit this [very in-detail blog post](https://huggingface.co/blog/stable_diffusion) +Please visit this [very in-detail blog post](https://huggingface.co/blog/stable_diffusion) on Stable Diffusion! diff --git a/docs/source/imgs/access_request.png b/docs/source/imgs/access_request.png new file mode 100644 index 000000000000..33c6abc88dfb Binary files /dev/null and b/docs/source/imgs/access_request.png differ diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 392b22399908..975ff47b61e6 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -34,9 +34,13 @@ available a colab notebook to directly try them out. | Pipeline | Paper | Tasks | Colab |---|---|:---:|:---:| +| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | +| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation | +| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | | [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | | [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | +| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image | | [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | | [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | | [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | @@ -44,6 +48,14 @@ available a colab notebook to directly try them out. | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation | +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting | +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image | +| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation | +| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index fa9466b6d1aa..9c93b359560e 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -12,9 +12,12 @@ specific language governing permissions and limitations under the License. # Installation -Install Diffusers for with PyTorch. Support for other libraries will come in the future +Install 🤗 Diffusers for whichever deep learning library you’re working with. -🤗 Diffusers is tested on Python 3.7+, and PyTorch 1.7.0+. +🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and flax. Follow the installation instructions below for the deep learning library you are using: + +- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions. +- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions. ## Install with pip @@ -36,12 +39,30 @@ source .env/bin/activate Now you're ready to install 🤗 Diffusers with the following command: +**For PyTorch** + +```bash +pip install diffusers["torch"] +``` + +**For Flax** + ```bash -pip install diffusers +pip install diffusers["flax"] ``` ## Install from source +Before intsalling `diffusers` from source, make sure you have `torch` and `accelerate` installed. + +For `torch` installation refer to the `torch` [docs](https://pytorch.org/get-started/locally/#start-locally). + +To install `accelerate` + +```bash +pip install accelerate +``` + Install 🤗 Diffusers from source with the following command: ```bash @@ -53,7 +74,7 @@ The `main` version is useful for staying up-to-date with the latest developments For instance, if a bug has been fixed since the last official release but a new release hasn't been rolled out yet. However, this means the `main` version may not always be stable. We strive to keep the `main` version operational, and most issues are usually resolved within a few hours or a day. -If you run into a problem, please open an [Issue](https://github.com/huggingface/transformers/issues) so we can fix it even sooner! +If you run into a problem, please open an [Issue](https://github.com/huggingface/transformers/issues), so we can fix it even sooner! ## Editable install @@ -67,7 +88,18 @@ Clone the repository and install 🤗 Diffusers with the following commands: ```bash git clone https://github.com/huggingface/diffusers.git cd diffusers -pip install -e . +``` + +**For PyTorch** + +``` +pip install -e ".[torch]" +``` + +**For Flax** + +``` +pip install -e ".[flax]" ``` These commands will link the folder you cloned the repository to and your Python library paths. diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx index b561aedbfe4a..4371daacc903 100644 --- a/docs/source/optimization/fp16.mdx +++ b/docs/source/optimization/fp16.mdx @@ -14,17 +14,21 @@ specific language governing permissions and limitations under the License. We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for memory or speed. - | | Latency | Speedup | -|------------------|---------|---------| +| ---------------- | ------- | ------- | | original | 9.50s | x1 | | cuDNN auto-tuner | 9.37s | x1.01 | -| autocast (fp16) | 5.47s | x1.91 | -| fp16 | 3.61s | x2.91 | -| channels last | 3.30s | x2.87 | +| autocast (fp16) | 5.47s | x1.74 | +| fp16 | 3.61s | x2.63 | +| channels last | 3.30s | x2.88 | | traced UNet | 3.21s | x2.96 | +| memory efficient attention | 2.63s | x3.61 | -obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM steps. + + obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from + the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM + steps. + ## Enable cuDNN auto-tuner @@ -56,12 +60,12 @@ If you use a CUDA GPU, you can take advantage of `torch.autocast` to perform inf from torch import autocast from diffusers import StableDiffusionPipeline -pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") +pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt).images[0] + image = pipe(prompt).images[0] ``` Despite the precision loss, in our experience the final image results look the same as the `float32` versions. Feel free to experiment and report back! @@ -72,14 +76,14 @@ To save more GPU memory and get even more speed, you can load and run the model ```Python pipe = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, ) pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" -image = pipe(prompt).images[0] +image = pipe(prompt).images[0] ``` ## Sliced attention for additional memory savings @@ -87,7 +91,10 @@ image = pipe(prompt).images[0] For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once. -Attention slicing is useful even if a batch size of just 1 is used - as long as the model uses more than one attention head. If there is more than one attention head the *QK^T* attention matrix can be computed sequentially for each head which can save a significant amount of memory. + Attention slicing is useful even if a batch size of just 1 is used - as long + as the model uses more than one attention head. If there is more than one + attention head the *QK^T* attention matrix can be computed sequentially for + each head which can save a significant amount of memory. To perform the attention computation sequentially over each head, you only need to invoke [`~StableDiffusionPipeline.enable_attention_slicing`] in your pipeline before inference, like here: @@ -97,7 +104,7 @@ import torch from diffusers import StableDiffusionPipeline pipe = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, ) @@ -105,11 +112,55 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" pipe.enable_attention_slicing() -image = pipe(prompt).images[0] +image = pipe(prompt).images[0] ``` There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM! +## Offloading to CPU with accelerate for memory savings + +For additional memory savings, you can offload the weights to CPU and load them to GPU when performing the forward pass. + +To perform CPU offloading, all you have to do is invoke [`~StableDiffusionPipeline.enable_sequential_cpu_offload`]: + +```Python +import torch +from diffusers import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + revision="fp16", + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +pipe.enable_sequential_cpu_offload() +image = pipe(prompt).images[0] +``` + +And you can get the memory consumption to < 2GB. + +If is also possible to chain it with attention slicing for minimal memory consumption, running it in as little as < 800mb of GPU vRAM: + +```Python +import torch +from diffusers import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + revision="fp16", + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +pipe.enable_sequential_cpu_offload() +pipe.enable_attention_slicing(1) + +image = pipe(prompt).images[0] +``` + ## Using Channels Last memory format Channels last memory format is an alternative way of ordering NCHW tensors in memory preserving dimensions ordering. Channels last tensors ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel). Since not all operators currently support channels last format it may result in a worst performance, so it's better to try it and see if it works for your model. @@ -152,7 +203,7 @@ def generate_inputs(): pipe = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, ).to("cuda") @@ -216,7 +267,7 @@ class UNet2DConditionOutput: pipe = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, ).to("cuda") @@ -240,3 +291,41 @@ pipe.unet = TracedUNet() with torch.inference_mode(): image = pipe([prompt] * 1, num_inference_steps=50).images[0] ``` + + +## Memory Efficient Attention +Recent work on optimizing the bandwitdh in the attention block have generated huge speed ups and gains in GPU memory usage. The most recent being Flash Attention (from @tridao, [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf)) . +Here are the speedups we obtain on a few Nvidia GPUs when running the inference at 512x512 with a batch size of 1 (one prompt): + +| GPU | Base Attention FP16 | Memory Efficient Attention FP16 | +|------------------ |--------------------- |--------------------------------- | +| NVIDIA Tesla T4 | 3.5it/s | 5.5it/s | +| NVIDIA 3060 RTX | 4.6it/s | 7.8it/s | +| NVIDIA A10G | 8.88it/s | 15.6it/s | +| NVIDIA RTX A6000 | 11.7it/s | 21.09it/s | +| NVIDIA TITAN RTX | 12.51it/s | 18.22it/s | +| A100-SXM4-40GB | 18.6it/s | 29.it/s | +| A100-SXM-80GB | 18.7it/s | 29.5it/s | + +To leverage it just make sure you have: + - PyTorch > 1.12 + - Cuda available + - Installed the [xformers](https://github.com/facebookresearch/xformers) library +```python +from diffusers import StableDiffusionPipeline +import torch + +pipe = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + revision="fp16", + torch_dtype=torch.float16, +).to("cuda") + +pipe.enable_xformers_memory_efficient_attention() + +with torch.inference_mode(): + sample = pipe("a small cat") + +# optional: You can disable it via +# pipe.disable_xformers_memory_efficient_attention() +``` \ No newline at end of file diff --git a/docs/source/optimization/mps.mdx b/docs/source/optimization/mps.mdx index ff9d614c870f..8a2d5ad763a2 100644 --- a/docs/source/optimization/mps.mdx +++ b/docs/source/optimization/mps.mdx @@ -17,9 +17,10 @@ specific language governing permissions and limitations under the License. ## Requirements - Mac computer with Apple silicon (M1/M2) hardware. -- macOS 12.3 or later. +- macOS 12.6 or later (13.0 or later recommended). - arm64 version of Python. -- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later. +- PyTorch 1.13. You can install it with `pip` or `conda` using the instructions in https://pytorch.org/get-started/locally/. + ## Inference Pipeline @@ -31,9 +32,12 @@ We recommend to "prime" the pipeline using an additional one-time pass through i # make sure you're logged in with `huggingface-cli login` from diffusers import StableDiffusionPipeline -pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") +pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") pipe = pipe.to("mps") +# Recommended if your computer has < 64 GB of RAM +pipe.enable_attention_slicing() + prompt = "a photo of an astronaut riding a horse on mars" # First-time "warmup" pass (see explanation above) @@ -43,16 +47,17 @@ _ = pipe(prompt, num_inference_steps=1) image = pipe(prompt).images[0] ``` -## Known Issues +## Performance Recommendations -- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372). -- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this might be related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039#issuecomment-1237735249), but we need to investigate in more depth. For now, we recommend to iterate instead of batching. +M1/M2 performance is very sensitive to memory pressure. The system will automatically swap if it needs to, but performance will degrade significantly when it does. -## Performance +We recommend you use _attention slicing_ to reduce memory pressure during inference and prevent swapping, particularly if your computer has lass than 64 GB of system RAM, or if you generate images at non-standard resolutions larger than 512 × 512 pixels. Attention slicing performs the costly attention operation in multiple steps instead of all at once. It usually has a performance impact of ~20% in computers without universal memory, but we have observed _better performance_ in most Apple Silicon computers, unless you have 64 GB or more. -These are the results we got on a M1 Max MacBook Pro with 64 GB of RAM, running macOS Ventura Version 13.0 Beta (22A5331f). We performed Stable Diffusion text-to-image generation of the same prompt for 50 inference steps, using a guidance scale of 7.5. +```python +pipeline.enable_attention_slicing() +``` + +## Known Issues -| Device | Steps | Time | -|--------|-------|---------| -| CPU | 50 | 213.46s | -| MPS | 50 | 30.81s | \ No newline at end of file +- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372). +- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this is related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039). This is being resolved, but for now we recommend to iterate instead of batching. diff --git a/docs/source/optimization/onnx.mdx b/docs/source/optimization/onnx.mdx index 9bbc4f2077c2..e79efbde0742 100644 --- a/docs/source/optimization/onnx.mdx +++ b/docs/source/optimization/onnx.mdx @@ -28,7 +28,7 @@ The snippet below demonstrates how to use the ONNX runtime. You need to use `Sta from diffusers import StableDiffusionOnnxPipeline pipe = StableDiffusionOnnxPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", revision="onnx", provider="CUDAExecutionProvider", ) diff --git a/docs/source/quicktour.mdx b/docs/source/quicktour.mdx index 9574ecac4a6a..a50b476c3d6a 100644 --- a/docs/source/quicktour.mdx +++ b/docs/source/quicktour.mdx @@ -41,7 +41,7 @@ In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generat ```python >>> from diffusers import DiffusionPipeline ->>> generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") +>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") ``` The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components. @@ -49,13 +49,13 @@ Because the model consists of roughly 1.4 billion parameters, we strongly recomm You can move the generator object to GPU, just like you would in PyTorch. ```python ->>> generator.to("cuda") +>>> pipeline.to("cuda") ``` -Now you can use the `generator` on your text prompt: +Now you can use the `pipeline` on your text prompt: ```python ->>> image = generator("An image of a squirrel in Picasso style").images[0] +>>> image = pipeline("An image of a squirrel in Picasso style").images[0] ``` The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class). @@ -68,8 +68,7 @@ You can save the image by simply calling: More advanced models, like [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) require you to accept a [license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) before running the model. This is due to the improved image generation capabilities of the model and the potentially harmful content that could be produced with it. -Long story short: Head over to your stable diffusion model of choice, *e.g.* [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4), read through the license and click-accept to get -access to the model. +Please, head over to your stable diffusion model of choice, *e.g.* [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license carefully and tick the checkbox if you agree. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). Having "click-accepted" the license, you can save your token: @@ -77,13 +76,13 @@ Having "click-accepted" the license, you can save your token: AUTH_TOKEN = "" ``` -You can then load [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) +You can then load [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) just like we did before only that now you need to pass your `AUTH_TOKEN`: ```python >>> from diffusers import DiffusionPipeline ->>> generator = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=AUTH_TOKEN) +>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN) ``` If you do not pass your authentication token you will see that the diffusion system will not be correctly @@ -95,15 +94,15 @@ the weights locally via: ``` git lfs install -git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 +git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 ``` and then load locally saved weights into the pipeline. This way, you do not need to pass an authentication -token. Assuming that `"./stable-diffusion-v1-4"` is the local path to the cloned stable-diffusion-v1-4 repo, +token. Assuming that `"./stable-diffusion-v1-5"` is the local path to the cloned stable-diffusion-v1-5 repo, you can also load the pipeline as follows: ```python ->>> generator = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-4") +>>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5") ``` Running the pipeline is then identical to the code above as it's the same model architecture. @@ -116,19 +115,20 @@ Running the pipeline is then identical to the code above as it's the same model Diffusion systems can be used with multiple different [schedulers](./api/schedulers) each with their pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to -use a different scheduler. *E.g.* if you would instead like to use the [`LMSDiscreteScheduler`] scheduler, +use a different scheduler. *E.g.* if you would instead like to use the [`EulerDiscreteScheduler`] scheduler, you could use it as follows: ```python ->>> from diffusers import LMSDiscreteScheduler +>>> from diffusers import EulerDiscreteScheduler ->>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") +>>> pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN) ->>> generator = StableDiffusionPipeline.from_pretrained( -... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=AUTH_TOKEN -... ) +>>> # change scheduler to Euler +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) ``` +For more in-detail information on how to change between schedulers, please refer to the [Using Schedulers](./using-diffusers/schedulers) guide. + [Stability AI's](https://stability.ai/) Stable Diffusion model is an impressive image generation model and can do much more than just generating images from text. We have dedicated a whole documentation page, just for Stable Diffusion [here](./conceptual/stable_diffusion). diff --git a/docs/source/training/dreambooth.mdx b/docs/source/training/dreambooth.mdx new file mode 100644 index 000000000000..238dcb24cf8f --- /dev/null +++ b/docs/source/training/dreambooth.mdx @@ -0,0 +1,240 @@ + + +# DreamBooth fine-tuning example + +[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like stable diffusion given just a few (3~5) images of a subject. + +![Dreambooth examples from the project's blog](https://dreambooth.github.io/DreamBooth_files/teaser_static.jpg) +_Dreambooth examples from the [project's blog](https://dreambooth.github.io)._ + +The [Dreambooth training script](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) shows how to implement this training procedure on a pre-trained Stable Diffusion model. + + + + + +Dreambooth fine-tuning is very sensitive to hyperparameters and easy to overfit. We recommend you take a look at our [in-depth analysis](https://huggingface.co/blog/dreambooth) with recommended settings for different subjects, and go from there. + + + +## Training locally + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies. We also recommend to install `diffusers` from the `main` github branch. + +```bash +pip install git+https://github.com/huggingface/diffusers +pip install -U -r diffusers/examples/dreambooth/requirements.txt +``` + +Then initialize and configure a [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. + +You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). + +Run the following command to authenticate your token + +```bash +huggingface-cli login +``` + +If you have already cloned the repo, then you won't need to go through these steps. Instead, you can pass the path to your local checkout to the training script and it will be loaded from there. + +### Dog toy example + +In this example we'll use [these images](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) to add a new concept to Stable Diffusion using the Dreambooth process. They will be our training data. Please, download them and place them somewhere in your system. + +Then you can launch the training script using: + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path_to_training_images" +export OUTPUT_DIR="path_to_saved_model" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of sks dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=400 +``` + +### Training with a prior-preserving loss + +Prior preservation is used to avoid overfitting and language-drift. Please, refer to the paper to learn more about it if you are interested. For prior preservation, we use other images of the same class as part of the training process. The nice thing is that we can generate those images using the Stable Diffusion model itself! The training script will save the generated images to a local path we specify. + +According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior preservation. 200-300 works well for most cases. + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path_to_training_images" +export CLASS_DIR="path_to_class_images" +export OUTPUT_DIR="path_to_saved_model" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --max_train_steps=800 +``` + +### Training on a 16GB GPU + +With the help of gradient checkpointing and the 8-bit optimizer from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes), it's possible to train dreambooth on a 16GB GPU. + +```bash +pip install bitsandbytes +``` + +Then pass the `--use_8bit_adam` option to the training script. + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path_to_training_images" +export CLASS_DIR="path_to_class_images" +export OUTPUT_DIR="path_to_saved_model" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=2 --gradient_checkpointing \ + --use_8bit_adam \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --max_train_steps=800 +``` + +### Fine-tune the text encoder in addition to the UNet + +The script also allows to fine-tune the `text_encoder` along with the `unet`. It has been observed experimentally that this gives much better results, especially on faces. Please, refer to [our blog](https://huggingface.co/blog/dreambooth) for more details. + +To enable this option, pass the `--train_text_encoder` argument to the training script. + + +Training the text encoder requires additional memory, so training won't fit on a 16GB GPU. You'll need at least 24GB VRAM to use this option. + + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path_to_training_images" +export CLASS_DIR="path_to_class_images" +export OUTPUT_DIR="path_to_saved_model" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_text_encoder \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --use_8bit_adam + --gradient_checkpointing \ + --learning_rate=2e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --max_train_steps=800 +``` + +### Training on a 8 GB GPU: + +Using [DeepSpeed](https://www.deepspeed.ai/) it's even possible to offload some +tensors from VRAM to either CPU or NVME, allowing training to proceed with less GPU memory. + +DeepSpeed needs to be enabled with `accelerate config`. During configuration, +answer yes to "Do you want to use DeepSpeed?". Combining DeepSpeed stage 2, fp16 +mixed precision, and offloading both the model parameters and the optimizer state to CPU, it's +possible to train on under 8 GB VRAM. The drawback is that this requires more system RAM (about 25 GB). See [the DeepSpeed documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more configuration options. + +Changing the default Adam optimizer to DeepSpeed's special version of Adam +`deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup, but enabling +it requires the system's CUDA toolchain version to be the same as the one installed with PyTorch. 8-bit optimizers don't seem to be compatible with DeepSpeed at the moment. + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path_to_training_images" +export CLASS_DIR="path_to_class_images" +export OUTPUT_DIR="path_to_saved_model" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --sample_batch_size=1 \ + --gradient_accumulation_steps=1 --gradient_checkpointing \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --max_train_steps=800 \ + --mixed_precision=fp16 +``` + +## Inference + +Once you have trained a model, inference can be done using the `StableDiffusionPipeline`, by simply indicating the path where the model was saved. Make sure that your prompts include the special `identifier` used during training (`sks` in the previous examples). + +```python +from diffusers import StableDiffusionPipeline +import torch + +model_id = "path_to_saved_model" +pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + +prompt = "A photo of sks dog in a bucket" +image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] + +image.save("dog-bucket.png") +``` diff --git a/docs/source/training/overview.mdx b/docs/source/training/overview.mdx index c403298493e5..9b36117ca43a 100644 --- a/docs/source/training/overview.mdx +++ b/docs/source/training/overview.mdx @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # 🧨 Diffusers Training Examples -Diffusers examples are a collection of scripts to demonstrate how to effectively use the `diffusers` library +Diffusers training examples are a collection of scripts to demonstrate how to effectively use the `diffusers` library for a variety of use cases. **Note**: If you are looking for **official** examples on how to use `diffusers` for inference, @@ -36,13 +36,15 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie - [Unconditional Training](./unconditional_training) - [Text-to-Image Training](./text2image) - [Text Inversion](./text_inversion) +- [Dreambooth](./dreambooth) | Task | 🤗 Accelerate | 🤗 Datasets | Colab |---|---|:---:|:---:| | [**Unconditional Image Generation**](./unconditional_training) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) -| [**Text-to-Image**](./text2image) | - | - | -| [**Text-Inversion**](./text_inversion) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) +| [**Text-to-Image fine-tuning**](./text2image) | ✅ | ✅ | +| [**Textual Inversion**](./text_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) +| [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb) ## Community diff --git a/docs/source/training/text2image.mdx b/docs/source/training/text2image.mdx index dcbdece429a9..eb71457cb758 100644 --- a/docs/source/training/text2image.mdx +++ b/docs/source/training/text2image.mdx @@ -11,6 +11,128 @@ specific language governing permissions and limitations under the License. --> -# Text-to-Image Training +# Stable Diffusion text-to-image fine-tuning -Under construction 🚧 +The [`train_text_to_image.py`](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) script shows how to fine-tune the stable diffusion model on your own dataset. + + + +The text-to-image fine-tuning script is experimental. It's easy to overfit and run into issues like catastrophic forgetting. We recommend to explore different hyperparameters to get the best results on your dataset. + + + + +## Running locally + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +```bash +pip install git+https://github.com/huggingface/diffusers.git +pip install -U -r requirements.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. + +You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). + +Run the following command to authenticate your token + +```bash +huggingface-cli login +``` + +If you have already cloned the repo, then you won't need to go through these steps. Instead, you can pass the path to your local checkout to the training script and it will be loaded from there. + +### Hardware Requirements for Fine-tuning + +Using `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with more than 30GB of GPU memory. You can also use JAX / Flax for fine-tuning on TPUs or GPUs, see [below](#flax-jax-finetuning) for details. + +### Fine-tuning Example + +The following script will launch a fine-tuning run using [Justin Pinkneys' captioned Pokemon dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions), available in Hugging Face Hub. + + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export dataset_name="lambdalabs/pokemon-blip-captions" + +accelerate launch train_text_to_image.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$dataset_name \ + --use_ema \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --mixed_precision="fp16" \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir="sd-pokemon-model" +``` + +To run on your own training files you need to prepare the dataset according to the format required by `datasets`. You can upload your dataset to the Hub, or you can prepare a local folder with your files. [This documentation](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata) explains how to do it. + +You should modify the script if you wish to use custom loading logic. We have left pointers in the code in the appropriate places :) + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export TRAIN_DIR="path_to_your_dataset" +export OUTPUT_DIR="path_to_save_model" + +accelerate launch train_text_to_image.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$TRAIN_DIR \ + --use_ema \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --mixed_precision="fp16" \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir=${OUTPUT_DIR} +``` + +Once training is finished the model will be saved to the `OUTPUT_DIR` specified in the command. To load the fine-tuned model for inference, just pass that path to `StableDiffusionPipeline`: + +```python +from diffusers import StableDiffusionPipeline + +model_path = "path_to_saved_model" +pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) +pipe.to("cuda") + +image = pipe(prompt="yoda").images[0] +image.save("yoda-pokemon.png") +``` + +### Flax / JAX fine-tuning + +Thanks to [@duongna211](https://github.com/duongna21) it's possible to fine-tune Stable Diffusion using Flax! This is very efficient on TPU hardware but works great on GPUs too. You can use the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py) like this: + +```Python +export MODEL_NAME="runwayml/stable-diffusion-v1-5" +export dataset_name="lambdalabs/pokemon-blip-captions" + +python train_text_to_image_flax.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$dataset_name \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --output_dir="sd-pokemon-model" +``` diff --git a/docs/source/training/text_inversion.mdx b/docs/source/training/text_inversion.mdx index 13ea7c942b4e..7bc145299eac 100644 --- a/docs/source/training/text_inversion.mdx +++ b/docs/source/training/text_inversion.mdx @@ -49,7 +49,7 @@ The `textual_inversion.py` script [here](https://github.com/huggingface/diffuser ### Installing the dependencies -Before running the scripts, make sure to install the library's training dependencies: +Before running the scripts, make sure to install the library's training dependencies. ```bash pip install diffusers[training] accelerate transformers @@ -64,7 +64,7 @@ accelerate config ### Cat toy example -You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. +You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). @@ -83,7 +83,7 @@ Now let's get our dataset.Download 3-4 images from [here](https://drive.google.c And launch the training using ```bash -export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export MODEL_NAME="runwayml/stable-diffusion-v1-5" export DATA_DIR="path-to-dir-containing-images" accelerate launch textual_inversion.py \ diff --git a/docs/source/using-diffusers/audio.mdx b/docs/source/using-diffusers/audio.mdx new file mode 100644 index 000000000000..5a5c2241ca75 --- /dev/null +++ b/docs/source/using-diffusers/audio.mdx @@ -0,0 +1,16 @@ + + +# Using Diffusers for audio + +The [`DanceDiffusionPipeline`] can be used to generate audio rapidly! +More coming soon! \ No newline at end of file diff --git a/docs/source/using-diffusers/conditional_image_generation.mdx b/docs/source/using-diffusers/conditional_image_generation.mdx index 6273a71d4c0a..5ed27ac9171c 100644 --- a/docs/source/using-diffusers/conditional_image_generation.mdx +++ b/docs/source/using-diffusers/conditional_image_generation.mdx @@ -44,5 +44,3 @@ You can save the image by simply calling: ```python >>> image.save("image_of_squirrel_painting.png") ``` - - diff --git a/docs/source/using-diffusers/configuration.mdx b/docs/source/using-diffusers/configuration.mdx index a64cce0606be..36a2ad0d0394 100644 --- a/docs/source/using-diffusers/configuration.mdx +++ b/docs/source/using-diffusers/configuration.mdx @@ -12,21 +12,10 @@ specific language governing permissions and limitations under the License. -# Quicktour +# Configuration -Start using Diffusers🧨 quickly! -To start, use the [`DiffusionPipeline`] for quick inference and sample generations! - -``` -pip install diffusers -``` - -## Main classes - -### Models - -### Schedulers - -### Pipelines +The handling of configurations in Diffusers is with the `ConfigMixin` class. +[[autodoc]] ConfigMixin +Under further construction 🚧, open a [PR](https://github.com/huggingface/diffusers/compare) if you want to contribute! diff --git a/docs/source/using-diffusers/contribute_pipeline.mdx b/docs/source/using-diffusers/contribute_pipeline.mdx new file mode 100644 index 000000000000..18e84cdfbc94 --- /dev/null +++ b/docs/source/using-diffusers/contribute_pipeline.mdx @@ -0,0 +1,169 @@ + + +# How to build a community pipeline + +*Note*: this page was built from the GitHub Issue on Community Pipelines [#841](https://github.com/huggingface/diffusers/issues/841). + +Let's make an example! +Say you want to define a pipeline that just does a single forward pass to a U-Net and then calls a scheduler only once (Note, this doesn't make any sense from a scientific point of view, but only represents an example of how things work under the hood). + +Cool! So you open your favorite IDE and start creating your pipeline 💻. +First, what model weights and configurations do we need? +We have a U-Net and a scheduler, so our pipeline should take a U-Net and a scheduler as an argument. +Also, as stated above, you'd like to be able to load weights and the scheduler config for Hub and share your code with others, so we'll inherit from `DiffusionPipeline`: + +```python +from diffusers import DiffusionPipeline +import torch + + +class UnetSchedulerOneForwardPipeline(DiffusionPipeline): + def __init__(self, unet, scheduler): + super().__init__() +``` + +Now, we must save the `unet` and `scheduler` in a config file so that you can save your pipeline with `save_pretrained`. +Therefore, make sure you add every component that is save-able to the `register_modules` function: + +```python +from diffusers import DiffusionPipeline +import torch + + +class UnetSchedulerOneForwardPipeline(DiffusionPipeline): + def __init__(self, unet, scheduler): + super().__init__() + + self.register_modules(unet=unet, scheduler=scheduler) +``` + +Cool, the init is done! 🔥 Now, let's go into the forward pass, which we recommend defining as `__call__` . Here you're given all the creative freedom there is. For our amazing "one-step" pipeline, we simply create a random image and call the unet once and the scheduler once: + +```python +from diffusers import DiffusionPipeline +import torch + + +class UnetSchedulerOneForwardPipeline(DiffusionPipeline): + def __init__(self, unet, scheduler): + super().__init__() + + self.register_modules(unet=unet, scheduler=scheduler) + + def __call__(self): + image = torch.randn( + (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + ) + timestep = 1 + + model_output = self.unet(image, timestep).sample + scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample + + return scheduler_output +``` + +Cool, that's it! 🚀 You can now run this pipeline by passing a `unet` and a `scheduler` to the init: + +```python +from diffusers import DDPMScheduler, Unet2DModel + +scheduler = DDPMScheduler() +unet = UNet2DModel() + +pipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler) + +output = pipeline() +``` + +But what's even better is that you can load pre-existing weights into the pipeline if they match exactly your pipeline structure. This is e.g. the case for [https://huggingface.co/google/ddpm-cifar10-32](https://huggingface.co/google/ddpm-cifar10-32) so that we can do the following: + +```python +pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-32") + +output = pipeline() +``` + +We want to share this amazing pipeline with the community, so we would open a PR request to add the following code under `one_step_unet.py` to [https://github.com/huggingface/diffusers/tree/main/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) . + +```python +from diffusers import DiffusionPipeline +import torch + + +class UnetSchedulerOneForwardPipeline(DiffusionPipeline): + def __init__(self, unet, scheduler): + super().__init__() + + self.register_modules(unet=unet, scheduler=scheduler) + + def __call__(self): + image = torch.randn( + (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + ) + timestep = 1 + + model_output = self.unet(image, timestep).sample + scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample + + return scheduler_output +``` + +Our amazing pipeline got merged here: [#840](https://github.com/huggingface/diffusers/pull/840). +Now everybody that has `diffusers >= 0.4.0` installed can use our pipeline magically 🪄 as follows: + +```python +from diffusers import DiffusionPipeline + +pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="one_step_unet") +pipe() +``` + +Another way to upload your custom_pipeline, besides sending a PR, is uploading the code that contains it to the Hugging Face Hub, [as exemplified here](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview#loading-custom-pipelines-from-the-hub). + +**Try it out now - it works!** + +In general, you will want to create much more sophisticated pipelines, so we recommend looking at existing pipelines here: [https://github.com/huggingface/diffusers/tree/main/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community). + +IMPORTANT: +You can use whatever package you want in your community pipeline file - as long as the user has it installed, everything will work fine. Make sure you have one and only one pipeline class that inherits from `DiffusionPipeline` as this will be automatically detected. + +## How do community pipelines work? +A community pipeline is a class that has to inherit from ['DiffusionPipeline']: +and that has been added to `examples/community` [files](https://github.com/huggingface/diffusers/tree/main/examples/community). +The community can load the pipeline code via the custom_pipeline argument from DiffusionPipeline. See docs [here](https://huggingface.co/docs/diffusers/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.custom_pipeline): + +This means: +The model weights and configs of the pipeline should be loaded from the `pretrained_model_name_or_path` [argument](https://huggingface.co/docs/diffusers/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path): +whereas the code that powers the community pipeline is defined in a file added in [`examples/community`](https://github.com/huggingface/diffusers/tree/main/examples/community). + +Now, it might very well be that only some of your pipeline components weights can be downloaded from an official repo. +The other components should then be passed directly to init as is the case for the ClIP guidance notebook [here](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb#scrollTo=z9Kglma6hjki). + +The magic behind all of this is that we load the code directly from GitHub. You can check it out in more detail if you follow the functionality defined here: + +```python +# 2. Load the pipeline class, if using custom module then load it from the hub +# if we load from explicit class, let's use it +if custom_pipeline is not None: + pipeline_class = get_class_from_dynamic_module( + custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline + ) +elif cls != DiffusionPipeline: + pipeline_class = cls +else: + diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) + pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) +``` + +This is why a community pipeline merged to GitHub will be directly available to all `diffusers` packages. + diff --git a/docs/source/using-diffusers/custom_pipeline_examples.mdx b/docs/source/using-diffusers/custom_pipeline_examples.mdx new file mode 100644 index 000000000000..b77e33be77d0 --- /dev/null +++ b/docs/source/using-diffusers/custom_pipeline_examples.mdx @@ -0,0 +1,283 @@ + + +# Custom Pipelines + +> **For more information about community pipelines, please have a look at [this issue](https://github.com/huggingface/diffusers/issues/841).** + +**Community** examples consist of both inference and training examples that have been added by the community. +Please have a look at the following table to get an overview of all community examples. Click on the **Code Example** to get a copy-and-paste ready code example that you can try out. +If a community doesn't work as expected, please open an issue and ping the author on it. + +| Example | Description | Code Example | Colab | Author | +|:---------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------:| +| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) | +| One Step U-Net (Dummy) | Example showcasing of how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | +| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | - | [Nate Raw](https://github.com/nateraw/) | +| Stable Diffusion Mega | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | +| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) | +| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech) + +To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. +```py +pipe = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", custom_pipeline="filename_in_the_community_folder" +) +``` + +## Example usages + +### CLIP Guided Stable Diffusion + +CLIP guided stable diffusion can help to generate more realistic images +by guiding stable diffusion at every denoising step with an additional CLIP model. + +The following code requires roughly 12GB of GPU RAM. + +```python +from diffusers import DiffusionPipeline +from transformers import CLIPFeatureExtractor, CLIPModel +import torch + + +feature_extractor = CLIPFeatureExtractor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K") +clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16) + + +guided_pipeline = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + custom_pipeline="clip_guided_stable_diffusion", + clip_model=clip_model, + feature_extractor=feature_extractor, + revision="fp16", + torch_dtype=torch.float16, +) +guided_pipeline.enable_attention_slicing() +guided_pipeline = guided_pipeline.to("cuda") + +prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece" + +generator = torch.Generator(device="cuda").manual_seed(0) +images = [] +for i in range(4): + image = guided_pipeline( + prompt, + num_inference_steps=50, + guidance_scale=7.5, + clip_guidance_scale=100, + num_cutouts=4, + use_cutouts=False, + generator=generator, + ).images[0] + images.append(image) + +# save images locally +for i, img in enumerate(images): + img.save(f"./clip_guided_sd/image_{i}.png") +``` + +The `images` list contains a list of PIL images that can be saved locally or displayed directly in a google colab. +Generated images tend to be of higher qualtiy than natively using stable diffusion. E.g. the above script generates the following images: + +![clip_guidance](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/clip_guidance/merged_clip_guidance.jpg). + +### One Step Unet + +The dummy "one-step-unet" can be run as follows: + +```python +from diffusers import DiffusionPipeline + +pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="one_step_unet") +pipe() +``` + +**Note**: This community pipeline is not useful as a feature, but rather just serves as an example of how community pipelines can be added (see https://github.com/huggingface/diffusers/issues/841). + +### Stable Diffusion Interpolation + +The following code can be run on a GPU of at least 8GB VRAM and should take approximately 5 minutes. + +```python +from diffusers import DiffusionPipeline +import torch + +pipe = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="fp16", + torch_dtype=torch.float16, + safety_checker=None, # Very important for videos...lots of false positives while interpolating + custom_pipeline="interpolate_stable_diffusion", +).to("cuda") +pipe.enable_attention_slicing() + +frame_filepaths = pipe.walk( + prompts=["a dog", "a cat", "a horse"], + seeds=[42, 1337, 1234], + num_interpolation_steps=16, + output_dir="./dreams", + batch_size=4, + height=512, + width=512, + guidance_scale=8.5, + num_inference_steps=50, +) +``` + +The output of the `walk(...)` function returns a list of images saved under the folder as defined in `output_dir`. You can use these images to create videos of stable diffusion. + +> **Please have a look at https://github.com/nateraw/stable-diffusion-videos for more in-detail information on how to create videos using stable diffusion as well as more feature-complete functionality.** + +### Stable Diffusion Mega + +The Stable Diffusion Mega Pipeline lets you use the main use cases of the stable diffusion pipeline in a single class. + +```python +#!/usr/bin/env python3 +from diffusers import DiffusionPipeline +import PIL +import requests +from io import BytesIO +import torch + + +def download_image(url): + response = requests.get(url) + return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + +pipe = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + custom_pipeline="stable_diffusion_mega", + torch_dtype=torch.float16, + revision="fp16", +) +pipe.to("cuda") +pipe.enable_attention_slicing() + + +### Text-to-Image + +images = pipe.text2img("An astronaut riding a horse").images + +### Image-to-Image + +init_image = download_image( + "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" +) + +prompt = "A fantasy landscape, trending on artstation" + +images = pipe.img2img(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images + +### Inpainting + +img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" +init_image = download_image(img_url).resize((512, 512)) +mask_image = download_image(mask_url).resize((512, 512)) + +prompt = "a cat sitting on a bench" +images = pipe.inpaint(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images +``` + +As shown above this one pipeline can run all both "text-to-image", "image-to-image", and "inpainting" in one pipeline. + +### Long Prompt Weighting Stable Diffusion + +The Pipeline lets you input prompt without 77 token length limit. And you can increase words weighting by using "()" or decrease words weighting by using "[]" +The Pipeline also lets you use the main use cases of the stable diffusion pipeline in a single class. + +#### pytorch + +```python +from diffusers import DiffusionPipeline +import torch + +pipe = DiffusionPipeline.from_pretrained( + "hakurei/waifu-diffusion", custom_pipeline="lpw_stable_diffusion", revision="fp16", torch_dtype=torch.float16 +) +pipe = pipe.to("cuda") + +prompt = "best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms" +neg_prompt = "lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry" + +pipe.text2img(prompt, negative_prompt=neg_prompt, width=512, height=512, max_embeddings_multiples=3).images[0] +``` + +#### onnxruntime + +```python +from diffusers import DiffusionPipeline +import torch + +pipe = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + custom_pipeline="lpw_stable_diffusion_onnx", + revision="onnx", + provider="CUDAExecutionProvider", +) + +prompt = "a photo of an astronaut riding a horse on mars, best quality" +neg_prompt = "lowres, bad anatomy, error body, error hair, error arm, error hands, bad hands, error fingers, bad fingers, missing fingers, error legs, bad legs, multiple legs, missing legs, error lighting, error shadow, error reflection, text, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" + +pipe.text2img(prompt, negative_prompt=neg_prompt, width=512, height=512, max_embeddings_multiples=3).images[0] +``` + +if you see `Token indices sequence length is longer than the specified maximum sequence length for this model ( *** > 77 ) . Running this sequence through the model will result in indexing errors`. Do not worry, it is normal. + +### Speech to Image + +The following code can generate an image from an audio sample using pre-trained OpenAI whisper-small and Stable Diffusion. + +```Python +import torch + +import matplotlib.pyplot as plt +from datasets import load_dataset +from diffusers import DiffusionPipeline +from transformers import ( + WhisperForConditionalGeneration, + WhisperProcessor, +) + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + +audio_sample = ds[3] + +text = audio_sample["text"].lower() +speech_data = audio_sample["audio"]["array"] + +model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device) +processor = WhisperProcessor.from_pretrained("openai/whisper-small") + +diffuser_pipeline = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + custom_pipeline="speech_to_image_diffusion", + speech_model=model, + speech_processor=processor, + revision="fp16", + torch_dtype=torch.float16, +) + +diffuser_pipeline.enable_attention_slicing() +diffuser_pipeline = diffuser_pipeline.to(device) + +output = diffuser_pipeline(speech_data) +plt.imshow(output.images[0]) +``` +This example produces the following image: + +![image](https://user-images.githubusercontent.com/45072645/196901736-77d9c6fc-63ee-4072-90b0-dc8b903d63e3.png) \ No newline at end of file diff --git a/docs/source/using-diffusers/custom_pipelines.mdx b/docs/source/using-diffusers/custom_pipeline_overview.mdx similarity index 98% rename from docs/source/using-diffusers/custom_pipelines.mdx rename to docs/source/using-diffusers/custom_pipeline_overview.mdx index b52d405581b1..ae5bad2d7bf2 100644 --- a/docs/source/using-diffusers/custom_pipelines.mdx +++ b/docs/source/using-diffusers/custom_pipeline_overview.mdx @@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Custom Pipelines +# Loading and Adding Custom Pipelines Diffusers allows you to conveniently load any custom pipeline from the Hugging Face Hub as well as any [official community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community) via the [`DiffusionPipeline`] class. @@ -58,7 +58,7 @@ feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id) clip_model = CLIPModel.from_pretrained(clip_model_id) pipeline = DiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", custom_pipeline="clip_guided_stable_diffusion", clip_model=clip_model, feature_extractor=feature_extractor, diff --git a/docs/source/using-diffusers/img2img.mdx b/docs/source/using-diffusers/img2img.mdx index 62eaeea911c9..911d7bd76ad0 100644 --- a/docs/source/using-diffusers/img2img.mdx +++ b/docs/source/using-diffusers/img2img.mdx @@ -25,7 +25,7 @@ from diffusers import StableDiffusionImg2ImgPipeline # load the pipeline device = "cuda" pipe = StableDiffusionImg2ImgPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16 + "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16 ).to(device) # let's download an initial image @@ -33,7 +33,7 @@ url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/st response = requests.get(url) init_image = Image.open(BytesIO(response.content)).convert("RGB") -init_image = init_image.resize((768, 512)) +init_image.thumbnail((768, 768)) prompt = "A fantasy landscape, trending on artstation" diff --git a/docs/source/using-diffusers/inpaint.mdx b/docs/source/using-diffusers/inpaint.mdx index 7b4687c21204..1bafa244559b 100644 --- a/docs/source/using-diffusers/inpaint.mdx +++ b/docs/source/using-diffusers/inpaint.mdx @@ -12,13 +12,19 @@ specific language governing permissions and limitations under the License. # Text-Guided Image-Inpainting -The [`StableDiffusionInpaintPipeline`] lets you edit specific parts of an image by providing a mask and text prompt. +The [`StableDiffusionInpaintPipeline`] lets you edit specific parts of an image by providing a mask and a text prompt. It uses a version of Stable Diffusion specifically trained for in-painting tasks. -```python -from io import BytesIO + +Note that this model is distributed separately from the regular Stable Diffusion model, so you have to accept its license even if you accepted the Stable Diffusion one in the past. -import requests +Please, visit the [model card](https://huggingface.co/runwayml/stable-diffusion-inpainting), read the license carefully and tick the checkbox if you agree. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section](https://huggingface.co/docs/hub/security-tokens) of the documentation. + + +```python import PIL +import requests +import torch +from io import BytesIO from diffusers import StableDiffusionInpaintPipeline @@ -34,15 +40,24 @@ mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data init_image = download_image(img_url).resize((512, 512)) mask_image = download_image(mask_url).resize((512, 512)) -device = "cuda" pipe = StableDiffusionInpaintPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16 -).to(device) + "runwayml/stable-diffusion-inpainting", + revision="fp16", + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") + +prompt = "Face of a yellow cat, high resolution, sitting on a park bench" +image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] +``` -prompt = "a cat sitting on a bench" -images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images +`image` | `mask_image` | `prompt` | **Output** | +:-------------------------:|:-------------------------:|:-------------------------:|-------------------------:| +drawing | drawing | ***Face of a yellow cat, high resolution, sitting on a park bench*** | drawing | -images[0].save("cat_on_bench.png") -``` You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) + + +A previous experimental implementation of in-painting used a different, lower-quality process. To ensure backwards compatibility, loading a pretrained pipeline that doesn't contain the new model will still apply the old in-painting method. + \ No newline at end of file diff --git a/docs/source/using-diffusers/loading.mdx b/docs/source/using-diffusers/loading.mdx index 44b514bbb299..c97ad5c5d0c9 100644 --- a/docs/source/using-diffusers/loading.mdx +++ b/docs/source/using-diffusers/loading.mdx @@ -10,6 +10,389 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Loading +# Loading -Under construction 🚧 +A core premise of the diffusers library is to make diffusion models **as accessible as possible**. +Accessibility is therefore achieved by providing an API to load complete diffusion pipelines as well as individual components with a single line of code. + +In the following we explain in-detail how to easily load: + +- *Complete Diffusion Pipelines* via the [`DiffusionPipeline.from_pretrained`] +- *Diffusion Models* via [`ModelMixin.from_pretrained`] +- *Schedulers* via [`SchedulerMixin.from_pretrained`] + +## Loading pipelines + +The [`DiffusionPipeline`] class is the easiest way to access any diffusion model that is [available on the Hub](https://huggingface.co/models?library=diffusers). Let's look at an example on how to download [CompVis' Latent Diffusion model](https://huggingface.co/CompVis/ldm-text2im-large-256). + +```python +from diffusers import DiffusionPipeline + +repo_id = "CompVis/ldm-text2im-large-256" +ldm = DiffusionPipeline.from_pretrained(repo_id) +``` + +Here [`DiffusionPipeline`] automatically detects the correct pipeline (*i.e.* [`LDMTextToImagePipeline`]), downloads and caches all required configuration and weight files (if not already done so), and finally returns a pipeline instance, called `ldm`. +The pipeline instance can then be called using [`LDMTextToImagePipeline.__call__`] (i.e., `ldm("image of a astronaut riding a horse")`) for text-to-image generation. + +Instead of using the generic [`DiffusionPipeline`] class for loading, you can also load the appropriate pipeline class directly. The code snippet above yields the same instance as when doing: + +```python +from diffusers import LDMTextToImagePipeline + +repo_id = "CompVis/ldm-text2im-large-256" +ldm = LDMTextToImagePipeline.from_pretrained(repo_id) +``` + +Diffusion pipelines like `LDMTextToImagePipeline` often consist of multiple components. These components can be both parameterized models, such as `"unet"`, `"vqvae"` and "bert", tokenizers or schedulers. These components can interact in complex ways with each other when using the pipeline in inference, *e.g.* for [`LDMTextToImagePipeline`] or [`StableDiffusionPipeline`] the inference call is explained [here](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work). +The purpose of the [pipeline classes](./api/overview#diffusers-summary) is to wrap the complexity of these diffusion systems and give the user an easy-to-use API while staying flexible for customization, as will be shown later. + +### Loading pipelines that require access request + +Due to the capabilities of diffusion models to generate extremely realistic images, there is a certain danger that such models might be misused for unwanted applications, *e.g.* generating pornography or violent images. +In order to minimize the possibility of such unsolicited use cases, some of the most powerful diffusion models require users to acknowledge a license before being able to use the model. If the user does not agree to the license, the pipeline cannot be downloaded. +If you try to load [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) the same way as done previously: + +```python +from diffusers import DiffusionPipeline + +repo_id = "runwayml/stable-diffusion-v1-5" +stable_diffusion = DiffusionPipeline.from_pretrained(repo_id) +``` + +it will only work if you have both *click-accepted* the license on [the model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) and are logged into the Hugging Face Hub. Otherwise you will get an error message +such as the following: + +``` +OSError: runwayml/stable-diffusion-v1-5 is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models' +If this is a private repository, make sure to pass a token having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` +``` + +Therefore, we need to make sure to *click-accept* the license. You can do this by simply visiting +the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) and clicking on "Agree and access repository": + +

+
+ +
+

+ +Second, you need to login with your access token: + +``` +huggingface-cli login +``` + +before trying to load the model. Or alternatively, you can pass [your access token](https://huggingface.co/docs/hub/security-tokens#user-access-tokens) directly via the flag `use_auth_token`. In this case you do **not** need +to run `huggingface-cli login` before: + +```python +from diffusers import DiffusionPipeline + +repo_id = "runwayml/stable-diffusion-v1-5" +stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, use_auth_token="") +``` + +The final option to use pipelines that require access without having to rely on the Hugging Face Hub is to load the pipeline locally as explained in the next section. + +### Loading pipelines locally + +If you prefer to have complete control over the pipeline and its corresponding files or, as said before, if you want to use pipelines that require an access request without having to be connected to the Hugging Face Hub, +we recommend loading pipelines locally. + +To load a diffusion pipeline locally, you first need to manually download the whole folder structure on your local disk and then pass a local path to the [`DiffusionPipeline.from_pretrained`]. Let's again look at an example for +[CompVis' Latent Diffusion model](https://huggingface.co/CompVis/ldm-text2im-large-256). + +First, you should make use of [`git-lfs`](https://git-lfs.github.com/) to download the whole folder structure that has been uploaded to the [model repository](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main): + +``` +git lfs install +git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 +``` + +The command above will create a local folder called `./stable-diffusion-v1-5` on your disk. +Now, all you have to do is to simply pass the local folder path to `from_pretrained`: + +```python +from diffusers import DiffusionPipeline + +repo_id = "./stable-diffusion-v1-5" +stable_diffusion = DiffusionPipeline.from_pretrained(repo_id) +``` + +If `repo_id` is a local path, as it is the case here, [`DiffusionPipeline.from_pretrained`] will automatically detect it and therefore not try to download any files from the Hub. +While we usually recommend to load weights directly from the Hub to be certain to stay up to date with the newest changes, loading pipelines locally should be preferred if one +wants to stay anonymous, self-contained applications, etc... + +### Loading customized pipelines + +Advanced users that want to load customized versions of diffusion pipelines can do so by swapping any of the default components, *e.g.* the scheduler, with other scheduler classes. +A classical use case of this functionality is to swap the scheduler. [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) uses the [`PNDMScheduler`] by default which is generally not the most performant scheduler. Since the release +of stable diffusion, multiple improved schedulers have been published. To use those, the user has to manually load their preferred scheduler and pass it into [`DiffusionPipeline.from_pretrained`]. + +*E.g.* to use [`EulerDiscreteScheduler`] or [`DPMSolverMultistepScheduler`] to have a better quality vs. generation speed trade-off for inference, one could load them as follows: + +```python +from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler + +repo_id = "runwayml/stable-diffusion-v1-5" + +scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") +# or +# scheduler = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler") + +stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler) +``` + +Three things are worth paying attention to here. +- First, the scheduler is loaded with [`SchedulerMixin.from_pretrained`] +- Second, the scheduler is loaded with a function argument, called `subfolder="scheduler"` as the configuration of stable diffusion's scheduling is defined in a [subfolder of the official pipeline repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/scheduler) +- Third, the scheduler instance can simply be passed with the `scheduler` keyword argument to [`DiffusionPipeline.from_pretrained`]. This works because the [`StableDiffusionPipeline`] defines its scheduler with the `scheduler` attribute. It's not possible to use a different name, such as `sampler=scheduler` since `sampler` is not a defined keyword for [`StableDiffusionPipeline.__init__`] + +Not only the scheduler components can be customized for diffusion pipelines; in theory, all components of a pipeline can be customized. In practice, however, it often only makes sense to switch out a component that has **compatible** alternatives to what the pipeline expects. +Many scheduler classes are compatible with each other as can be seen [here](https://github.com/huggingface/diffusers/blob/0dd8c6b4dbab4069de9ed1cafb53cbd495873879/src/diffusers/schedulers/scheduling_ddim.py#L112). This is not always the case for other components, such as the `"unet"`. + +One special case that can also be customized is the `"safety_checker"` of stable diffusion. If you believe the safety checker doesn't serve you any good, you can simply disable it by passing `None`: + +```python +from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler + +stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, safety_checker=None) +``` + +Another common use case is to reuse the same components in multiple pipelines, *e.g.* the weights and configurations of [`"runwayml/stable-diffusion-v1-5"`](https://huggingface.co/runwayml/stable-diffusion-v1-5) can be used for both [`StableDiffusionPipeline`] and [`StableDiffusionImg2ImgPipeline`] and we might not want to +use the exact same weights into RAM twice. In this case, customizing all the input instances would help us +to only load the weights into RAM once: + +```python +from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline + +model_id = "runwayml/stable-diffusion-v1-5" +stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id) + +components = stable_diffusion_txt2img.components + +# weights are not reloaded into RAM +stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(**components) +``` + +Note how the above code snippet makes use of [`DiffusionPipeline.components`]. + +### How does loading work? + +As a class method, [`DiffusionPipeline.from_pretrained`] is responsible for two things: +- Download the latest version of the folder structure required to run the `repo_id` with `diffusers` and cache them. If the latest folder structure is available in the local cache, [`DiffusionPipeline.from_pretrained`] will simply reuse the cache and **not** re-download the files. +- Load the cached weights into the _correct_ pipeline class – one of the [officially supported pipeline classes](./api/overview#diffusers-summary) - and return an instance of the class. The _correct_ pipeline class is thereby retrieved from the `model_index.json` file. + +The underlying folder structure of diffusion pipelines correspond 1-to-1 to their corresponding class instances, *e.g.* [`LDMTextToImagePipeline`] for [`CompVis/ldm-text2im-large-256`](https://huggingface.co/CompVis/ldm-text2im-large-256) +This can be understood better by looking at an example. Let's print out pipeline class instance `pipeline` we just defined: + +```python +from diffusers import DiffusionPipeline + +repo_id = "CompVis/ldm-text2im-large-256" +ldm = DiffusionPipeline.from_pretrained(repo_id) +print(ldm) +``` + +*Output*: +``` +LDMTextToImagePipeline { + "bert": [ + "latent_diffusion", + "LDMBertModel" + ], + "scheduler": [ + "diffusers", + "DDIMScheduler" + ], + "tokenizer": [ + "transformers", + "BertTokenizer" + ], + "unet": [ + "diffusers", + "UNet2DConditionModel" + ], + "vqvae": [ + "diffusers", + "AutoencoderKL" + ] +} +``` + +First, we see that the official pipeline is the [`LDMTextToImagePipeline`], and second we see that the `LDMTextToImagePipeline` consists of 5 components: +- `"bert"` of class `LDMBertModel` as defined [in the pipeline](https://github.com/huggingface/diffusers/blob/cd502b25cf0debac6f98d27a6638ef95208d1ea2/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py#L664) +- `"scheduler"` of class [`DDIMScheduler`] +- `"tokenizer"` of class `BertTokenizer` as defined [in `transformers`](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer) +- `"unet"` of class [`UNet2DConditionModel`] +- `"vqvae"` of class [`AutoencoderKL`] + +Let's now compare the pipeline instance to the folder structure of the model repository `CompVis/ldm-text2im-large-256`. Looking at the folder structure of [`CompVis/ldm-text2im-large-256`](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main) on the Hub, we can see it matches 1-to-1 the printed out instance of `LDMTextToImagePipeline` above: + +``` +. +├── bert +│   ├── config.json +│   └── pytorch_model.bin +├── model_index.json +├── scheduler +│   └── scheduler_config.json +├── tokenizer +│   ├── special_tokens_map.json +│   ├── tokenizer_config.json +│   └── vocab.txt +├── unet +│   ├── config.json +│   └── diffusion_pytorch_model.bin +└── vqvae + ├── config.json + └── diffusion_pytorch_model.bin +``` + +As we can see each attribute of the instance of `LDMTextToImagePipeline` has its configuration and possibly weights defined in a subfolder that is called **exactly** like the class attribute (`"bert"`, `"scheduler"`, `"tokenizer"`, `"unet"`, `"vqvae"`). Importantly, every pipeline expects a `model_index.json` file that tells the `DiffusionPipeline` both: +- which pipeline class should be loaded, and +- what sub-classes from which library are stored in which subfolders + +In the case of `CompVis/ldm-text2im-large-256` the `model_index.json` is therefore defined as follows: + +``` +{ + "_class_name": "LDMTextToImagePipeline", + "_diffusers_version": "0.0.4", + "bert": [ + "latent_diffusion", + "LDMBertModel" + ], + "scheduler": [ + "diffusers", + "DDIMScheduler" + ], + "tokenizer": [ + "transformers", + "BertTokenizer" + ], + "unet": [ + "diffusers", + "UNet2DConditionModel" + ], + "vqvae": [ + "diffusers", + "AutoencoderKL" + ] +} +``` + +- `_class_name` tells `DiffusionPipeline` which pipeline class should be loaded. +- `_diffusers_version` can be useful to know under which `diffusers` version this model was created. +- Every component of the pipeline is then defined under the form: +``` +"name" : [ + "library", + "class" +] +``` + - The `"name"` field corresponds both to the name of the subfolder in which the configuration and weights are stored as well as the attribute name of the pipeline class (as can be seen [here](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main/bert) and [here](https://github.com/huggingface/diffusers/blob/cd502b25cf0debac6f98d27a6638ef95208d1ea2/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py#L42) + - The `"library"` field corresponds to the name of the library, *e.g.* `diffusers` or `transformers` from which the `"class"` should be loaded + - The `"class"` field corresponds to the name of the class, *e.g.* [`BertTokenizer`](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer) or [`UNet2DConditionModel`] + + +## Loading models + +Models as defined under [src/diffusers/models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) can be loaded via the [`ModelMixin.from_pretrained`] function. The API is very similar the [`DiffusionPipeline.from_pretrained`] and works in the same way: +- Download the latest version of the model weights and configuration with `diffusers` and cache them. If the latest files are available in the local cache, [`ModelMixin.from_pretrained`] will simply reuse the cache and **not** re-download the files. +- Load the cached weights into the _defined_ model class - one of [the existing model classes](./api/models) - and return an instance of the class. + +In constrast to [`DiffusionPipeline.from_pretrained`], models rely on fewer files that usually don't require a folder structure, but just a `diffusion_pytorch_model.bin` and `config.json` file. + +Let's look at an example: + +```python +from diffusers import UNet2DConditionModel + +repo_id = "CompVis/ldm-text2im-large-256" +model = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet") +``` + +Note how we have to define the `subfolder="unet"` argument to tell [`ModelMixin.from_pretrained`] that the model weights are located in a [subfolder of the repository](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main/unet). + +As explained in [Loading customized pipelines]("./using-diffusers/loading#loading-customized-pipelines"), one can pass a loaded model to a diffusion pipeline, via [`DiffusionPipeline.from_pretrained`]: + +```python +from diffusers import DiffusionPipeline + +repo_id = "CompVis/ldm-text2im-large-256" +ldm = DiffusionPipeline.from_pretrained(repo_id, unet=model) +``` + +If the model files can be found directly at the root level, which is usually only the case for some very simple diffusion models, such as [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32), we don't +need to pass a `subfolder` argument: + +```python +from diffusers import UNet2DModel + +repo_id = "google/ddpm-cifar10-32" +model = UNet2DModel.from_pretrained(repo_id) +``` + +## Loading schedulers + +Schedulers rely on [`SchedulerMixin.from_pretrained`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file. +For consistency, we use the same method name as we do for models or pipelines, but no weights are loaded in this case. + +In constrast to pipelines or models, loading schedulers does not consume any significant amount of memory and the same configuration file can often be used for a variety of different schedulers. +For example, all of: + +- [`DDPMScheduler`] +- [`DDIMScheduler`] +- [`PNDMScheduler`] +- [`LMSDiscreteScheduler`] +- [`EulerDiscreteScheduler`] +- [`EulerAncestralDiscreteScheduler`] +- [`DPMSolverMultistepScheduler`] + +are compatible with [`StableDiffusionPipeline`] and therefore the same scheduler configuration file can be loaded in any of those classes: + +```python +from diffusers import StableDiffusionPipeline +from diffusers import ( + DDPMScheduler, + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, +) + +repo_id = "runwayml/stable-diffusion-v1-5" + +ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler") +ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler") +pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler") +lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") +euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") +euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") +dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler") + +# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler`, `euler_anc` +pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm) +``` + +## API + +[[autodoc]] modeling_utils.ModelMixin + - from_pretrained + - save_pretrained + +[[autodoc]] pipeline_utils.DiffusionPipeline + - from_pretrained + - save_pretrained + +[[autodoc]] modeling_flax_utils.FlaxModelMixin + - from_pretrained + - save_pretrained + +[[autodoc]] pipeline_flax_utils.FlaxDiffusionPipeline + - from_pretrained + - save_pretrained diff --git a/docs/source/using-diffusers/other-modalities.mdx b/docs/source/using-diffusers/other-modalities.mdx new file mode 100644 index 000000000000..1dc0877adb24 --- /dev/null +++ b/docs/source/using-diffusers/other-modalities.mdx @@ -0,0 +1,20 @@ + + +# Using Diffusers with other modalities + +Diffusers is in the process of expanding to modalities other than images. + +Currently, one example is for [molecule conformation](https://www.nature.com/subjects/molecular-conformation#:~:text=Definition,to%20changes%20in%20their%20environment.) generation. +* Generate conformations in Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb) + +More coming soon! \ No newline at end of file diff --git a/docs/source/using-diffusers/rl.mdx b/docs/source/using-diffusers/rl.mdx new file mode 100644 index 000000000000..6e18e07001b6 --- /dev/null +++ b/docs/source/using-diffusers/rl.mdx @@ -0,0 +1,18 @@ + + +# Using Diffusers for reinforcement learning + +Support for one RL model and related pipelines is included in the `experimental` source of diffusers. + +To try some of this in colab, please look at the following example: +* Model-based reinforcement learning on Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg) diff --git a/docs/source/using-diffusers/schedulers.mdx b/docs/source/using-diffusers/schedulers.mdx new file mode 100644 index 000000000000..87ff789747a3 --- /dev/null +++ b/docs/source/using-diffusers/schedulers.mdx @@ -0,0 +1,262 @@ + + +# Schedulers + +Diffusion pipelines are inherently a collection of diffusion models and schedulers that are partly independent from each other. This means that one is able to switch out parts of the pipeline to better customize +a pipeline to one's use case. The best example of this are the [Schedulers](../api/schedulers.mdx). + +Whereas diffusion models usually simply define the forward pass from noise to a less noisy sample, +schedulers define the whole denoising process, *i.e.*: +- How many denoising steps? +- Stochastic or deterministic? +- What algorithm to use to find the denoised sample + +They can be quite complex and often define a trade-off between **denoising speed** and **denoising quality**. +It is extremely difficult to measure quantitatively which scheduler works best for a given diffusion pipeline, so it is often recommended to simply try out which works best. + +The following paragraphs shows how to do so with the 🧨 Diffusers library. + +## Load pipeline + +Let's start by loading the stable diffusion pipeline. +Remember that you have to be a registered user on the 🤗 Hugging Face Hub, and have "click-accepted" the [license](https://huggingface.co/runwayml/stable-diffusion-v1-5) in order to use stable diffusion. + +```python +from huggingface_hub import login +from diffusers import DiffusionPipeline +import torch + +# first we need to login with our access token +login() + +# Now we can download the pipeline +pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) +``` + +Next, we move it to GPU: + +```python +pipeline.to("cuda") +``` + +## Access the scheduler + +The scheduler is always one of the components of the pipeline and is usually called `"scheduler"`. +So it can be accessed via the `"scheduler"` property. + +```python +pipeline.scheduler +``` + +**Output**: +``` +PNDMScheduler { + "_class_name": "PNDMScheduler", + "_diffusers_version": "0.8.0.dev0", + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "clip_sample": false, + "num_train_timesteps": 1000, + "set_alpha_to_one": false, + "skip_prk_steps": true, + "steps_offset": 1, + "trained_betas": null +} +``` + +We can see that the scheduler is of type [`PNDMScheduler`]. +Cool, now let's compare the scheduler in its performance to other schedulers. +First we define a prompt on which we will test all the different schedulers: + +```python +prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition." +``` + +Next, we create a generator from a random seed that will ensure that we can generate similar images as well as run the pipeline: + +```python +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator).images[0] +image +``` + +

+
+ +
+

+ + +## Changing the scheduler + +Now we show how easy it is to change the scheduler of a pipeline. Every scheduler has a property [`SchedulerMixin.compatibles`] +which defines all compatible schedulers. You can take a look at all available, compatible schedulers for the Stable Diffusion pipeline as follows. + +```python +pipeline.scheduler.compatibles +``` + +**Output**: +``` +[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler, + diffusers.schedulers.scheduling_ddim.DDIMScheduler, + diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler, + diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler, + diffusers.schedulers.scheduling_pndm.PNDMScheduler, + diffusers.schedulers.scheduling_ddpm.DDPMScheduler, + diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler] +``` + +Cool, lots of schedulers to look at. Feel free to have a look at their respective class definitions: + +- [`LMSDiscreteScheduler`], +- [`DDIMScheduler`], +- [`DPMSolverMultistepScheduler`], +- [`EulerDiscreteScheduler`], +- [`PNDMScheduler`], +- [`DDPMScheduler`], +- [`EulerAncestralDiscreteScheduler`]. + +We will now compare the input prompt with all other schedulers. To change the scheduler of the pipeline you can make use of the +convenient [`ConfigMixin.config`] property in combination with the [`ConfigMixin.from_config`] function. + +```python +pipeline.scheduler.config +``` + +returns a dictionary of the configuration of the scheduler: + +**Output**: +``` +FrozenDict([('num_train_timesteps', 1000), + ('beta_start', 0.00085), + ('beta_end', 0.012), + ('beta_schedule', 'scaled_linear'), + ('trained_betas', None), + ('skip_prk_steps', True), + ('set_alpha_to_one', False), + ('steps_offset', 1), + ('_class_name', 'PNDMScheduler'), + ('_diffusers_version', '0.8.0.dev0'), + ('clip_sample', False)]) +``` + +This configuration can then be used to instantiate a scheduler +of a different class that is compatible with the pipeline. Here, +we change the scheduler to the [`DDIMScheduler`]. + +```python +from diffusers import DDIMScheduler + +pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) +``` + +Cool, now we can run the pipeline again to compare the generation quality. + +```python +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator).images[0] +image +``` + +

+
+ +
+

+ + +## Compare schedulers + +So far we have tried running the stable diffusion pipeline with two schedulers: [`PNDMScheduler`] and [`DDIMScheduler`]. +A number of better schedulers have been released that can be run with much fewer steps, let's compare them here: + +[`LMSDiscreteScheduler`] usually leads to better results: + +```python +from diffusers import LMSDiscreteScheduler + +pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) + +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator).images[0] +image +``` + +

+
+ +
+

+ + +[`EulerDiscreteScheduler`] and [`EulerAncestralDiscreteScheduler`] can generate high quality results with as little as 30 steps. + +```python +from diffusers import EulerDiscreteScheduler + +pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0] +image +``` + +

+
+ +
+

+ + +and: + +```python +from diffusers import EulerAncestralDiscreteScheduler + +pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config) + +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0] +image +``` + +

+
+ +
+

+ + +At the time of writing this doc [`DPMSolverMultistepScheduler`] gives arguably the best speed/quality trade-off and can be run with as little +as 20 steps. + +```python +from diffusers import DPMSolverMultistepScheduler + +pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0] +image +``` + +

+
+ +
+

+ +As you can see most images look very similar and are arguably of very similar quality. It often really depends on the specific use case which scheduler to choose. A good approach is always to run multiple different +schedulers to compare results. diff --git a/examples/README.md b/examples/README.md index 2680b638d585..06ce06b9e3bc 100644 --- a/examples/README.md +++ b/examples/README.md @@ -16,7 +16,7 @@ limitations under the License. # 🧨 Diffusers Examples Diffusers examples are a collection of scripts to demonstrate how to effectively use the `diffusers` library -for a variety of use cases. +for a variety of use cases involving training or fine-tuning. **Note**: If you are looking for **official** examples on how to use `diffusers` for inference, please have a look at [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines) @@ -38,7 +38,11 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie | Task | 🤗 Accelerate | 🤗 Datasets | Colab |---|---|:---:|:---:| -| [**Unconditional Image Generation**](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [**Unconditional Image Generation**](./unconditional_image_generation) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ | +| [**Textual Inversion**](./textual_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) +| [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb) +| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffusers_locomotion.py) | - | - | coming soon. ## Community diff --git a/examples/community/README.md b/examples/community/README.md index 9edbad7d8b6f..108f6f95f1b4 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -6,16 +6,29 @@ Please have a look at the following table to get an overview of all community examples. Click on the **Code Example** to get a copy-and-paste ready code example that you can try out. If a community doesn't work as expected, please open an issue and ping the author on it. -| Example | Description | Code Example | Colab | Author | -|:----------|:----------------------|:-----------------|:-------------|----------:| -| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion| [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) | -| One Step U-Net (Dummy) | Example showcasing of how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | -| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | - | [Nate Raw](https://github.com/nateraw/) | -| Stable Diffusion Mega | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | +| Example | Description | Code Example | Colab | Author | +|:---------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------:| +| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) | +| One Step U-Net (Dummy) | Example showcasing of how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | +| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | - | [Nate Raw](https://github.com/nateraw/) | +| Stable Diffusion Mega | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | +| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) | +| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech) +| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | - | [Shyam Sudhakaran](https://github.com/shyamsn97) | +| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) | +| Seed Resizing Stable Diffusion| Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) | +| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image| [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) | +| Multilingual Stable Diffusion| Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) | +| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting| [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) | +| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting| [Text Based Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) | +| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - |[Stuti R.](https://github.com/kingstut) | +| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | + + To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. ```py -pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="filename_in_the_community_folder") +pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="filename_in_the_community_folder") ``` ## Example usages @@ -38,7 +51,7 @@ clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", guided_pipeline = DiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", custom_pipeline="clip_guided_stable_diffusion", clip_model=clip_model, feature_extractor=feature_extractor, @@ -138,7 +151,7 @@ def download_image(url): response = requests.get(url) return PIL.Image.open(BytesIO(response.content)).convert("RGB") -pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="stable_diffusion_mega", torch_dtype=torch.float16, revision="fp16") +pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_mega", torch_dtype=torch.float16, revision="fp16") pipe.to("cuda") pipe.enable_attention_slicing() @@ -168,3 +181,548 @@ images = pipe.inpaint(prompt=prompt, init_image=init_image, mask_image=mask_imag As shown above this one pipeline can run all both "text-to-image", "image-to-image", and "inpainting" in one pipeline. +### Long Prompt Weighting Stable Diffusion +Features of this custom pipeline: +- Input a prompt without the 77 token length limit. +- Includes tx2img, img2img. and inpainting pipelines. +- Emphasize/weigh part of your prompt with parentheses as so: `a baby deer with (big eyes)` +- De-emphasize part of your prompt as so: `a [baby] deer with big eyes` +- Precisely weigh part of your prompt as so: `a baby deer with (big eyes:1.3)` + +Prompt weighting equivalents: +- `a baby deer with` == `(a baby deer with:1.0)` +- `(big eyes)` == `(big eyes:1.1)` +- `((big eyes))` == `(big eyes:1.21)` +- `[big eyes]` == `(big eyes:0.91)` + +You can run this custom pipeline as so: + +#### pytorch + +```python +from diffusers import DiffusionPipeline +import torch + +pipe = DiffusionPipeline.from_pretrained( + 'hakurei/waifu-diffusion', + custom_pipeline="lpw_stable_diffusion", + revision="fp16", + torch_dtype=torch.float16 +) +pipe=pipe.to("cuda") + +prompt = "best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms" +neg_prompt = "lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry" + +pipe.text2img(prompt, negative_prompt=neg_prompt, width=512,height=512,max_embeddings_multiples=3).images[0] + +``` + +#### onnxruntime + +```python +from diffusers import DiffusionPipeline +import torch + +pipe = DiffusionPipeline.from_pretrained( + 'CompVis/stable-diffusion-v1-4', + custom_pipeline="lpw_stable_diffusion_onnx", + revision="onnx", + provider="CUDAExecutionProvider" +) + +prompt = "a photo of an astronaut riding a horse on mars, best quality" +neg_prompt = "lowres, bad anatomy, error body, error hair, error arm, error hands, bad hands, error fingers, bad fingers, missing fingers, error legs, bad legs, multiple legs, missing legs, error lighting, error shadow, error reflection, text, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" + +pipe.text2img(prompt,negative_prompt=neg_prompt, width=512, height=512, max_embeddings_multiples=3).images[0] + +``` + +if you see `Token indices sequence length is longer than the specified maximum sequence length for this model ( *** > 77 ) . Running this sequence through the model will result in indexing errors`. Do not worry, it is normal. + +### Speech to Image + +The following code can generate an image from an audio sample using pre-trained OpenAI whisper-small and Stable Diffusion. + +```Python +import torch + +import matplotlib.pyplot as plt +from datasets import load_dataset +from diffusers import DiffusionPipeline +from transformers import ( + WhisperForConditionalGeneration, + WhisperProcessor, +) + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + +audio_sample = ds[3] + +text = audio_sample["text"].lower() +speech_data = audio_sample["audio"]["array"] + +model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device) +processor = WhisperProcessor.from_pretrained("openai/whisper-small") + +diffuser_pipeline = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + custom_pipeline="speech_to_image_diffusion", + speech_model=model, + speech_processor=processor, + revision="fp16", + torch_dtype=torch.float16, +) + +diffuser_pipeline.enable_attention_slicing() +diffuser_pipeline = diffuser_pipeline.to(device) + +output = diffuser_pipeline(speech_data) +plt.imshow(output.images[0]) +``` +This example produces the following image: + +![image](https://user-images.githubusercontent.com/45072645/196901736-77d9c6fc-63ee-4072-90b0-dc8b903d63e3.png) + +### Wildcard Stable Diffusion +Following the great examples from https://github.com/jtkelm2/stable-diffusion-webui-1/blob/master/scripts/wildcards.py and https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts#wildcards, here's a minimal implementation that allows for users to add "wildcards", denoted by `__wildcard__` to prompts that are used as placeholders for randomly sampled values given by either a dictionary or a `.txt` file. For example: + +Say we have a prompt: + +``` +prompt = "__animal__ sitting on a __object__ wearing a __clothing__" +``` + +We can then define possible values to be sampled for `animal`, `object`, and `clothing`. These can either be from a `.txt` with the same name as the category. + +The possible values can also be defined / combined by using a dictionary like: `{"animal":["dog", "cat", mouse"]}`. + +The actual pipeline works just like `StableDiffusionPipeline`, except the `__call__` method takes in: + +`wildcard_files`: list of file paths for wild card replacement +`wildcard_option_dict`: dict with key as `wildcard` and values as a list of possible replacements +`num_prompt_samples`: number of prompts to sample, uniformly sampling wildcards + +A full example: + +create `animal.txt`, with contents like: + +``` +dog +cat +mouse +``` + +create `object.txt`, with contents like: + +``` +chair +sofa +bench +``` + +```python +from diffusers import DiffusionPipeline +import torch + +pipe = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + custom_pipeline="wildcard_stable_diffusion", + revision="fp16", + torch_dtype=torch.float16, +) +prompt = "__animal__ sitting on a __object__ wearing a __clothing__" +out = pipe( + prompt, + wildcard_option_dict={ + "clothing":["hat", "shirt", "scarf", "beret"] + }, + wildcard_files=["object.txt", "animal.txt"], + num_prompt_samples=1 +) +``` + +### Composable Stable diffusion + +[Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) proposes conjunction and negation (negative prompts) operators for compositional generation with conditional diffusion models. + +```python +import torch as th +import numpy as np +import torchvision.utils as tvu +from diffusers import DiffusionPipeline + +has_cuda = th.cuda.is_available() +device = th.device('cpu' if not has_cuda else 'cuda') + +pipe = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + use_auth_token=True, + custom_pipeline="composable_stable_diffusion", +).to(device) + + +def dummy(images, **kwargs): + return images, False + +pipe.safety_checker = dummy + +images = [] +generator = torch.Generator("cuda").manual_seed(0) + +seed = 0 +prompt = "a forest | a camel" +weights = " 1 | 1" # Equal weight to each prompt. Can be negative + +images = [] +for i in range(4): + res = pipe( + prompt, + guidance_scale=7.5, + num_inference_steps=50, + weights=weights, + generator=generator) + image = res.images[0] + images.append(image) + +for i, img in enumerate(images): + img.save(f"./composable_diffusion/image_{i}.png") +``` + +### Imagic Stable Diffusion +Allows you to edit an image using stable diffusion. + +```python +import requests +from PIL import Image +from io import BytesIO +import torch +import os +from diffusers import DiffusionPipeline, DDIMScheduler +has_cuda = torch.cuda.is_available() +device = torch.device('cpu' if not has_cuda else 'cuda') +pipe = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + safety_checker=None, + use_auth_token=True, + custom_pipeline="imagic_stable_diffusion", + scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) +).to(device) +generator = th.Generator("cuda").manual_seed(0) +seed = 0 +prompt = "A photo of Barack Obama smiling with a big grin" +url = 'https://www.dropbox.com/s/6tlwzr73jd1r9yk/obama.png?dl=1' +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((512, 512)) +res = pipe.train( + prompt, + init_image, + guidance_scale=7.5, + num_inference_steps=50, + generator=generator) +res = pipe(alpha=1) +os.makedirs("imagic", exist_ok=True) +image = res.images[0] +image.save('./imagic/imagic_image_alpha_1.png') +res = pipe(alpha=1.5) +image = res.images[0] +image.save('./imagic/imagic_image_alpha_1_5.png') +res = pipe(alpha=2) +image = res.images[0] +image.save('./imagic/imagic_image_alpha_2.png') +``` + +### Seed Resizing +Test seed resizing. Originally generate an image in 512 by 512, then generate image with same seed at 512 by 592 using seed resizing. Finally, generate 512 by 592 using original stable diffusion pipeline. + +```python +import torch as th +import numpy as np +from diffusers import DiffusionPipeline + +has_cuda = th.cuda.is_available() +device = th.device('cpu' if not has_cuda else 'cuda') + +pipe = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + use_auth_token=True, + custom_pipeline="seed_resize_stable_diffusion" +).to(device) + +def dummy(images, **kwargs): + return images, False + +pipe.safety_checker = dummy + + +images = [] +th.manual_seed(0) +generator = th.Generator("cuda").manual_seed(0) + +seed = 0 +prompt = "A painting of a futuristic cop" + +width = 512 +height = 512 + +res = pipe( + prompt, + guidance_scale=7.5, + num_inference_steps=50, + height=height, + width=width, + generator=generator) +image = res.images[0] +image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=height)) + + +th.manual_seed(0) +generator = th.Generator("cuda").manual_seed(0) + +pipe = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + use_auth_token=True, + custom_pipeline="/home/mark/open_source/diffusers/examples/community/" +).to(device) + +width = 512 +height = 592 + +res = pipe( + prompt, + guidance_scale=7.5, + num_inference_steps=50, + height=height, + width=width, + generator=generator) +image = res.images[0] +image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=height)) + +pipe_compare = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + use_auth_token=True, + custom_pipeline="/home/mark/open_source/diffusers/examples/community/" +).to(device) + +res = pipe_compare( + prompt, + guidance_scale=7.5, + num_inference_steps=50, + height=height, + width=width, + generator=generator +) + +image = res.images[0] +image.save('./seed_resize/seed_resize_{w}_{h}_image_compare.png'.format(w=width, h=height)) +``` + +### Multilingual Stable Diffusion Pipeline + +The following code can generate an images from texts in different languages using the pre-trained [mBART-50 many-to-one multilingual machine translation model](https://huggingface.co/facebook/mbart-large-50-many-to-one-mmt) and Stable Diffusion. + +```python +from PIL import Image + +import torch + +from diffusers import DiffusionPipeline +from transformers import ( + pipeline, + MBart50TokenizerFast, + MBartForConditionalGeneration, +) +device = "cuda" if torch.cuda.is_available() else "cpu" +device_dict = {"cuda": 0, "cpu": -1} + +# helper function taken from: https://huggingface.co/blog/stable_diffusion +def image_grid(imgs, rows, cols): + assert len(imgs) == rows*cols + + w, h = imgs[0].size + grid = Image.new('RGB', size=(cols*w, rows*h)) + grid_w, grid_h = grid.size + + for i, img in enumerate(imgs): + grid.paste(img, box=(i%cols*w, i//cols*h)) + return grid + +# Add language detection pipeline +language_detection_model_ckpt = "papluca/xlm-roberta-base-language-detection" +language_detection_pipeline = pipeline("text-classification", + model=language_detection_model_ckpt, + device=device_dict[device]) + +# Add model for language translation +trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt") +trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device) + +diffuser_pipeline = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + custom_pipeline="multilingual_stable_diffusion", + detection_pipeline=language_detection_pipeline, + translation_model=trans_model, + translation_tokenizer=trans_tokenizer, + revision="fp16", + torch_dtype=torch.float16, +) + +diffuser_pipeline.enable_attention_slicing() +diffuser_pipeline = diffuser_pipeline.to(device) + +prompt = ["a photograph of an astronaut riding a horse", + "Una casa en la playa", + "Ein Hund, der Orange isst", + "Un restaurant parisien"] + +output = diffuser_pipeline(prompt) + +images = output.images + +grid = image_grid(images, rows=2, cols=2) +``` + +This example produces the following images: +![image](https://user-images.githubusercontent.com/4313860/198328706-295824a4-9856-4ce5-8e66-278ceb42fd29.png) + +### Image to Image Inpainting Stable Diffusion + +Similar to the standard stable diffusion inpainting example, except with the addition of an `inner_image` argument. + +`image`, `inner_image`, and `mask` should have the same dimensions. `inner_image` should have an alpha (transparency) channel. + +The aim is to overlay two images, then mask out the boundary between `image` and `inner_image` to allow stable diffusion to make the connection more seamless. +For example, this could be used to place a logo on a shirt and make it blend seamlessly. + +```python +import PIL +import torch + +from diffusers import StableDiffusionInpaintPipeline + +image_path = "./path-to-image.png" +inner_image_path = "./path-to-inner-image.png" +mask_path = "./path-to-mask.png" + +init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512)) +inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512)) +mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512)) + +pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", + revision="fp16", + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") + +prompt = "Your prompt here!" +image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0] +``` + +### Text Based Inpainting Stable Diffusion + +Use a text prompt to generate the mask for the area to be inpainted. +Currently uses the CLIPSeg model for mask generation, then calls the standard Stable Diffusion Inpainting pipeline to perform the inpainting. + +```python +from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation +from diffusers import DiffusionPipeline + +from PIL import Image +import requests +from torch import autocast + +processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") +model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined") + +pipe = DiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", + custom_pipeline="text_inpainting", + segmentation_model=model, + segmentation_processor=processor +) +pipe = pipe.to("cuda") + + +url = "https://github.com/timojl/clipseg/blob/master/example_image.jpg?raw=true" +image = Image.open(requests.get(url, stream=True).raw).resize((512, 512)) +text = "a glass" # will mask out this text +prompt = "a cup" # the masked out region will be replaced with this + +with autocast("cuda"): + image = pipe(image=image, text=text, prompt=prompt).images[0] +``` + +### Bit Diffusion +Based https://arxiv.org/abs/2208.04202, this is used for diffusion on discrete data - eg, discreate image data, DNA sequence data. An unconditional discreate image can be generated like this: + +```python +from diffusers import DiffusionPipeline +pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="bit_diffusion") +image = pipe().images[0] + +``` + +### Stable Diffusion with K Diffusion + +Make sure you have @crowsonkb's https://github.com/crowsonkb/k-diffusion installed: + +``` +pip install k-diffusion +``` + +You can use the community pipeline as follows: + +```python +from diffusers import DiffusionPipeline + +pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion") +pipe = pipe.to("cuda") + +prompt = "an astronaut riding a horse on mars" +pipe.set_sampler("sample_heun") +generator = torch.Generator(device="cuda").manual_seed(seed) +image = pipe(prompt, generator=generator, num_inference_steps=20).images[0] + +image.save("./astronaut_heun_k_diffusion.png") +``` + +To make sure that K Diffusion and `diffusers` yield the same results: + +**Diffusers**: +```python +from diffusers import DiffusionPipeline, EulerDiscreteScheduler + +seed = 33 + +pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") +pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") + +generator = torch.Generator(device="cuda").manual_seed(seed) +image = pipe(prompt, generator=generator, num_inference_steps=50).images[0] +``` + +![diffusers_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler.png) + +**K Diffusion**: +```python +from diffusers import DiffusionPipeline, EulerDiscreteScheduler + +seed = 33 + +pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion") +pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") + +pipe.set_sampler("sample_euler") +generator = torch.Generator(device="cuda").manual_seed(seed) +image = pipe(prompt, generator=generator, num_inference_steps=50).images[0] +``` + +![diffusers_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler_k_diffusion.png) + diff --git a/examples/community/bit_diffusion.py b/examples/community/bit_diffusion.py new file mode 100644 index 000000000000..956e25a7e5c5 --- /dev/null +++ b/examples/community/bit_diffusion.py @@ -0,0 +1,265 @@ +from typing import Optional, Tuple, Union + +import torch + +from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.pipeline_utils import ImagePipelineOutput +from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput +from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput +from einops import rearrange, reduce + + +BITS = 8 + + +# convert to bit representations and back taken from https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py +def decimal_to_bits(x, bits=BITS): + """expects image tensor ranging from 0 to 1, outputs bit tensor ranging from -1 to 1""" + device = x.device + + x = (x * 255).int().clamp(0, 255) + + mask = 2 ** torch.arange(bits - 1, -1, -1, device=device) + mask = rearrange(mask, "d -> d 1 1") + x = rearrange(x, "b c h w -> b c 1 h w") + + bits = ((x & mask) != 0).float() + bits = rearrange(bits, "b c d h w -> b (c d) h w") + bits = bits * 2 - 1 + return bits + + +def bits_to_decimal(x, bits=BITS): + """expects bits from -1 to 1, outputs image tensor from 0 to 1""" + device = x.device + + x = (x > 0).int() + mask = 2 ** torch.arange(bits - 1, -1, -1, device=device, dtype=torch.int32) + + mask = rearrange(mask, "d -> d 1 1") + x = rearrange(x, "b (c d) h w -> b c d h w", d=8) + dec = reduce(x * mask, "b c d h w -> b c h w", "sum") + return (dec / 255).clamp(0.0, 1.0) + + +# modified scheduler step functions for clamping the predicted x_0 between -bit_scale and +bit_scale +def ddim_bit_scheduler_step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = True, + generator=None, + return_dict: bool = True, +) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): TODO + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class + Returns: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + scale = self.bit_scale + if self.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -scale, scale) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the model_output is always re-derived from the clipped x_0 in Glide + model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 + device = model_output.device if torch.is_tensor(model_output) else "cpu" + noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) + variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + +def ddpm_bit_scheduler_step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + prediction_type="epsilon", + generator=None, + return_dict: bool = True, +) -> Union[DDPMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the samples (`sample`). + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class + Returns: + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + t = timestep + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif prediction_type == "sample": + pred_original_sample = model_output + else: + raise ValueError(f"Unsupported prediction_type {prediction_type}.") + + # 3. Clip "predicted x_0" + scale = self.bit_scale + if self.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -scale, scale) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t + current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if t > 0: + noise = torch.randn( + model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator + ).to(model_output.device) + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample,) + + return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + + +class BitDiffusion(DiffusionPipeline): + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, DDPMScheduler], + bit_scale: Optional[float] = 1.0, + ): + super().__init__() + self.bit_scale = bit_scale + self.scheduler.step = ( + ddim_bit_scheduler_step if isinstance(scheduler, DDIMScheduler) else ddpm_bit_scheduler_step + ) + + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + height: Optional[int] = 256, + width: Optional[int] = 256, + num_inference_steps: Optional[int] = 50, + generator: Optional[torch.Generator] = None, + batch_size: Optional[int] = 1, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + latents = torch.randn( + (batch_size, self.unet.in_channels, height, width), + generator=generator, + ) + latents = decimal_to_bits(latents) * self.bit_scale + latents = latents.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # predict the noise residual + noise_pred = self.unet(latents, t).sample + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + image = bits_to_decimal(latents) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py index 974f4ab2e883..7a319bddf053 100644 --- a/examples/community/clip_guided_stable_diffusion.py +++ b/examples/community/clip_guided_stable_diffusion.py @@ -5,7 +5,14 @@ from torch import nn from torch.nn import functional as F -from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DiffusionPipeline, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, +) from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput from torchvision import transforms from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer @@ -56,7 +63,7 @@ def __init__( clip_model: CLIPModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[PNDMScheduler, LMSDiscreteScheduler], + scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler], feature_extractor: CLIPFeatureExtractor, ): super().__init__() @@ -71,7 +78,12 @@ def __init__( ) self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std) - self.make_cutouts = MakeCutouts(feature_extractor.size) + cut_out_size = ( + feature_extractor.size + if isinstance(feature_extractor.size, int) + else feature_extractor.size["shortest_edge"] + ) + self.make_cutouts = MakeCutouts(cut_out_size) set_requires_grad(self.text_encoder, False) set_requires_grad(self.clip_model, False) @@ -123,7 +135,7 @@ def cond_fn( # predict the noise residual noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample - if isinstance(self.scheduler, PNDMScheduler): + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): alpha_prod_t = self.scheduler.alphas_cumprod[timestep] beta_prod_t = 1 - alpha_prod_t # compute predicted original sample from predicted noise also called @@ -176,6 +188,7 @@ def __call__( num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, clip_guidance_scale: Optional[float] = 100, clip_prompt: Optional[Union[str, List[str]]] = None, num_cutouts: Optional[int] = 4, @@ -249,7 +262,7 @@ def __call__( latents_dtype = text_embeddings.dtype if latents is None: if self.device.type == "mps": - # randn does not exist on mps + # randn does not work reproducibly on mps latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( self.device ) @@ -275,6 +288,20 @@ def __call__( # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -306,7 +333,7 @@ def __call__( ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py new file mode 100644 index 000000000000..eb207e1bdd47 --- /dev/null +++ b/examples/community/composable_stable_diffusion.py @@ -0,0 +1,329 @@ +""" + modified based on diffusion library from Huggingface: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +""" +import inspect +import warnings +from typing import List, Optional, Union + +import torch + +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + + +class ComposableStableDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offsensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + weights: Optional[str] = "", + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if "|" in prompt: + prompt = [x.strip() for x in prompt.split("|")] + print(f"composing {prompt}...") + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + + if not weights: + # specify weights for prompts (excluding the unconditional score) + print("using equal weights for all prompts...") + pos_weights = torch.tensor( + [1 / (text_embeddings.shape[0] - 1)] * (text_embeddings.shape[0] - 1), device=self.device + ).reshape(-1, 1, 1, 1) + neg_weights = torch.tensor([1.0], device=self.device).reshape(-1, 1, 1, 1) + mask = torch.tensor([False] + [True] * pos_weights.shape[0], dtype=torch.bool) + else: + # set prompt weight for each + num_prompts = len(prompt) if isinstance(prompt, list) else 1 + weights = [float(w.strip()) for w in weights.split("|")] + if len(weights) < num_prompts: + weights.append(1.0) + weights = torch.tensor(weights, device=self.device) + assert len(weights) == text_embeddings.shape[0], "weights specified are not equal to the number of prompts" + pos_weights = [] + neg_weights = [] + mask = [] # first one is unconditional score + for w in weights: + if w > 0: + pos_weights.append(w) + mask.append(True) + else: + neg_weights.append(abs(w)) + mask.append(False) + # normalize the weights + pos_weights = torch.tensor(pos_weights, device=self.device).reshape(-1, 1, 1, 1) + pos_weights = pos_weights / pos_weights.sum() + neg_weights = torch.tensor(neg_weights, device=self.device).reshape(-1, 1, 1, 1) + neg_weights = neg_weights / neg_weights.sum() + mask = torch.tensor(mask, device=self.device, dtype=torch.bool) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + + if torch.all(mask): + # no negative prompts, so we use empty string as the negative prompt + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # update negative weights + neg_weights = torch.tensor([1.0], device=self.device) + mask = torch.tensor([False] + mask.detach().tolist(), device=self.device, dtype=torch.bool) + + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_device = "cpu" if self.device.type == "mps" else self.device + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + if latents is None: + latents = torch.randn( + latents_shape, + generator=generator, + device=latents_device, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents * self.scheduler.sigmas[0] + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * text_embeddings.shape[0]) if do_classifier_free_guidance else latents + ) + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[i] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + + # reduce memory by predicting each score sequentially + noise_preds = [] + # predict the noise residual + for latent_in, text_embedding_in in zip( + torch.chunk(latent_model_input, chunks=latent_model_input.shape[0], dim=0), + torch.chunk(text_embeddings, chunks=text_embeddings.shape[0], dim=0), + ): + noise_preds.append(self.unet(latent_in, t, encoder_hidden_states=text_embedding_in).sample) + noise_preds = torch.cat(noise_preds, dim=0) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond = (noise_preds[~mask] * neg_weights).sum(dim=0, keepdims=True) + noise_pred_text = (noise_preds[mask] * pos_weights).sum(dim=0, keepdims=True) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + # run safety checker + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/imagic_stable_diffusion.py b/examples/community/imagic_stable_diffusion.py new file mode 100644 index 000000000000..65966b4830e8 --- /dev/null +++ b/examples/community/imagic_stable_diffusion.py @@ -0,0 +1,497 @@ +""" + modeled after the textual_inversion.py / train_dreambooth.py and the work + of justinpinkney here: https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb +""" +import inspect +import warnings +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F + +import PIL +from accelerate import Accelerator +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import logging + +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +class ImagicStableDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for imagic image editing. + See paper here: https://arxiv.org/pdf/2210.09276.pdf + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offsensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def train( + self, + prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + height: Optional[int] = 512, + width: Optional[int] = 512, + generator: Optional[torch.Generator] = None, + embedding_learning_rate: float = 0.001, + diffusion_model_learning_rate: float = 2e-6, + text_embedding_optimization_steps: int = 500, + model_fine_tuning_optimization_steps: int = 1000, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + accelerator = Accelerator( + gradient_accumulation_steps=1, + mixed_precision="fp16", + ) + + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." + " Consider using `pipe.to(torch_device)` instead." + ) + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # Freeze vae and unet + self.vae.requires_grad_(False) + self.unet.requires_grad_(False) + self.text_encoder.requires_grad_(False) + self.unet.eval() + self.vae.eval() + self.text_encoder.eval() + + if accelerator.is_main_process: + accelerator.init_trackers( + "imagic", + config={ + "embedding_learning_rate": embedding_learning_rate, + "text_embedding_optimization_steps": text_embedding_optimization_steps, + }, + ) + + # get text embeddings for prompt + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncaton=True, + return_tensors="pt", + ) + text_embeddings = torch.nn.Parameter( + self.text_encoder(text_input.input_ids.to(self.device))[0], requires_grad=True + ) + text_embeddings = text_embeddings.detach() + text_embeddings.requires_grad_() + text_embeddings_orig = text_embeddings.clone() + + # Initialize the optimizer + optimizer = torch.optim.Adam( + [text_embeddings], # only optimize the embeddings + lr=embedding_learning_rate, + ) + + if isinstance(init_image, PIL.Image.Image): + init_image = preprocess(init_image) + + latents_dtype = text_embeddings.dtype + init_image = init_image.to(device=self.device, dtype=latents_dtype) + init_latent_image_dist = self.vae.encode(init_image).latent_dist + init_image_latents = init_latent_image_dist.sample(generator=generator) + init_image_latents = 0.18215 * init_image_latents + + progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + global_step = 0 + + logger.info("First optimizing the text embedding to better reconstruct the init image") + for _ in range(text_embedding_optimization_steps): + with accelerator.accumulate(text_embeddings): + # Sample noise that we'll add to the latents + noise = torch.randn(init_image_latents.shape).to(init_image_latents.device) + timesteps = torch.randint(1000, (1,), device=init_image_latents.device) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = self.scheduler.add_noise(init_image_latents, noise, timesteps) + + # Predict the noise residual + noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample + + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + accelerator.backward(loss) + + optimizer.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + logs = {"loss": loss.detach().item()} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + accelerator.wait_for_everyone() + + text_embeddings.requires_grad_(False) + + # Now we fine tune the unet to better reconstruct the image + self.unet.requires_grad_(True) + self.unet.train() + optimizer = torch.optim.Adam( + self.unet.parameters(), # only optimize unet + lr=diffusion_model_learning_rate, + ) + progress_bar = tqdm(range(model_fine_tuning_optimization_steps), disable=not accelerator.is_local_main_process) + + logger.info("Next fine tuning the entire model to better reconstruct the init image") + for _ in range(model_fine_tuning_optimization_steps): + with accelerator.accumulate(self.unet.parameters()): + # Sample noise that we'll add to the latents + noise = torch.randn(init_image_latents.shape).to(init_image_latents.device) + timesteps = torch.randint(1000, (1,), device=init_image_latents.device) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = self.scheduler.add_noise(init_image_latents, noise, timesteps) + + # Predict the noise residual + noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample + + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + accelerator.backward(loss) + + optimizer.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + logs = {"loss": loss.detach().item()} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + accelerator.wait_for_everyone() + self.text_embeddings_orig = text_embeddings_orig + self.text_embeddings = text_embeddings + + @torch.no_grad() + def __call__( + self, + alpha: float = 1.2, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + guidance_scale: float = 7.5, + eta: float = 0.0, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if self.text_embeddings is None: + raise ValueError("Please run the pipe.train() before trying to generate an image.") + if self.text_embeddings_orig is None: + raise ValueError("Please run the pipe.train() before trying to generate an image.") + + text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens = [""] + max_length = self.tokenizer.model_max_length + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.view(1, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = (1, self.unet.in_channels, height // 8, width // 8) + latents_dtype = text_embeddings.dtype + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( + self.device + ) + else: + latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # Some schedulers like PNDM have timesteps as arrays + # It's more optimized to move all timesteps to correct device beforehand + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( + self.device + ) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) + ) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py new file mode 100644 index 000000000000..3fa7db13a482 --- /dev/null +++ b/examples/community/img2img_inpainting.py @@ -0,0 +1,463 @@ +import inspect +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch + +import PIL +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import deprecate, logging +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def prepare_mask_and_masked_image(image, mask): + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + return mask, masked_image + + +def check_size(image, height, width): + if isinstance(image, PIL.Image.Image): + w, h = image.size + elif isinstance(image, torch.Tensor): + *_, h, w = image.shape + + if h != height or w != width: + raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}") + + +def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)): + inner_image = inner_image.convert("RGBA") + image = image.convert("RGB") + + image.paste(inner_image, paste_offset, inner_image) + image = image.convert("RGB") + + return image + + +class ImageToImageInpaintingPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image-to-image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image], + inner_image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + inner_image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be overlayed onto `image`. Non-transparent + regions of `inner_image` must fit inside white pixels in `mask_image`. Expects four channels, with + the last channel representing the alpha channel, which will be used to blend `inner_image` with + `image`. If not provided, it will be forcibly cast to RGBA. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # check if input sizes are correct + check_size(image, height, width) + check_size(inner_image, height, width) + check_size(mask_image, height, width) + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + num_channels_latents = self.vae.config.latent_channels + latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) + latents_dtype = text_embeddings.dtype + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( + self.device + ) + else: + latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + # overlay the inner image + image = overlay_inner_image(image, inner_image) + + # prepare mask and masked_image + mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + mask = mask.to(device=self.device, dtype=text_embeddings.dtype) + masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype) + + # resize the mask to latents shape as we concatenate the mask to the latents + mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8)) + + # encode the mask image into latents space so we can concatenate it to the latents + masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) + masked_image_latents = 0.18215 * masked_image_latents + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + mask = mask.repeat(batch_size * num_images_per_prompt, 1, 1, 1) + masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # Some schedulers like PNDM have timesteps as arrays + # It's more optimized to move all timesteps to correct device beforehand + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( + self.device + ) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) + ) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/interpolate_stable_diffusion.py b/examples/community/interpolate_stable_diffusion.py index 97116bdc77b4..4d7a73f5ba69 100644 --- a/examples/community/interpolate_stable_diffusion.py +++ b/examples/community/interpolate_stable_diffusion.py @@ -65,7 +65,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline): [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. @@ -101,7 +101,7 @@ def __init__( scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None: - logger.warn( + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" @@ -278,7 +278,7 @@ def __call__( if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: - uncond_tokens = [""] + uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" @@ -307,7 +307,7 @@ def __call__( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. @@ -324,7 +324,7 @@ def __call__( latents_dtype = text_embeddings.dtype if latents is None: if self.device.type == "mps": - # randn does not exist on mps + # randn does not work reproducibly on mps latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( self.device ) diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py new file mode 100644 index 000000000000..0e7dc9e1ed11 --- /dev/null +++ b/examples/community/lpw_stable_diffusion.py @@ -0,0 +1,1148 @@ +import inspect +import re +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import deprecate, is_accelerate_available, logging + +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word).input_ids[1:-1] + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + pipe: DiffusionPipeline, + text_input: torch.Tensor, + chunk_length: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + text_input_chunk[:, -1] = text_input[0, -1] + text_embedding = pipe.text_encoder(text_input_chunk)[0] + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + text_embeddings = pipe.text_encoder(text_input)[0] + return text_embeddings + + +def get_weighted_text_embeddings( + pipe: DiffusionPipeline, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 1, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + **kwargs, +): + r""" + Prompts can be assigned with local weights using brackets. For example, + prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', + and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. + + Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. + + Args: + pipe (`DiffusionPipeline`): + Pipe to provide access to the tokenizer and the text encoder. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + uncond_prompt (`str` or `List[str]`): + The unconditional prompt or prompts for guide the image generation. If unconditional prompt + is provided, the embeddings of prompt and uncond_prompt are concatenated. + max_embeddings_multiples (`int`, *optional*, defaults to `1`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + no_boseos_middle (`bool`, *optional*, defaults to `False`): + If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and + ending token in each of the chunk in the middle. + skip_parsing (`bool`, *optional*, defaults to `False`): + Skip the parsing of brackets. + skip_weighting (`bool`, *optional*, defaults to `False`): + Skip the weighting. When the parsing is skipped, it is forced True. + """ + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) + else: + prompt_tokens = [ + token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids + ] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [ + token[1:-1] + for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids + ] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = pipe.tokenizer.bos_token_id + eos = pipe.tokenizer.eos_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) + + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( + pipe, + prompt_tokens, + pipe.tokenizer.model_max_length, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) + if uncond_prompt is not None: + uncond_embeddings = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.tokenizer.model_max_length, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, uncond_embeddings + return text_embeddings, None + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"]) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing + weighting in prompt. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = self.device + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + strength: float = 0.8, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # get prompt text embeddings + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + text_embeddings, uncond_embeddings = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + **kwargs, + ) + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + bs_embed, seq_len, _ = uncond_embeddings.shape + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None + noise = None + + if init_image is None: + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = ( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height // 8, + width // 8, + ) + + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + if isinstance(init_image, PIL.Image.Image): + init_image = preprocess_image(init_image) + # encode the init image into latents and scale the latents + init_image = init_image.to(device=self.device, dtype=latents_dtype) + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = preprocess_mask(mask_image) + mask_image = mask_image.to(device=self.device, dtype=latents_dtype) + mask = torch.cat([mask_image] * batch_size * num_images_per_prompt) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + if self.device.type == "mps": + # randn does not exist on mps + noise = torch.randn( + init_latents.shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + noise = torch.randn( + init_latents.shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + latents = self.scheduler.add_noise(init_latents, noise, timesteps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( + self.device + ) + image, has_nsfw_concept = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype), + ) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def text2img( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function for text-to-image generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) + + def img2img( + self, + init_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function for image-to-image generation. + Args: + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + init_image=init_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) + + def inpaint( + self, + init_image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function for inpaint. + Args: + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + init_image=init_image, + mask_image=mask_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py new file mode 100644 index 000000000000..577772b9c36a --- /dev/null +++ b/examples/community/lpw_stable_diffusion_onnx.py @@ -0,0 +1,1013 @@ +import inspect +import re +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from diffusers.onnx_utils import OnnxRuntimeModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import logging + +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTokenizer + + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(pipe, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word, return_tensors="np").input_ids[0, 1:-1] + text_token += list(token) + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + pipe, + text_input: np.array, + chunk_length: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].copy() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + text_input_chunk[:, -1] = text_input[0, -1] + + text_embedding = pipe.text_encoder(input_ids=text_input_chunk)[0] + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = np.concatenate(text_embeddings, axis=1) + else: + text_embeddings = pipe.text_encoder(input_ids=text_input)[0] + return text_embeddings + + +def get_weighted_text_embeddings( + pipe, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 4, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + **kwargs, +): + r""" + Prompts can be assigned with local weights using brackets. For example, + prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', + and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. + + Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. + + Args: + pipe (`DiffusionPipeline`): + Pipe to provide access to the tokenizer and the text encoder. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + uncond_prompt (`str` or `List[str]`): + The unconditional prompt or prompts for guide the image generation. If unconditional prompt + is provided, the embeddings of prompt and uncond_prompt are concatenated. + max_embeddings_multiples (`int`, *optional*, defaults to `1`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + no_boseos_middle (`bool`, *optional*, defaults to `False`): + If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and + ending token in each of the chunk in the middle. + skip_parsing (`bool`, *optional*, defaults to `False`): + Skip the parsing of brackets. + skip_weighting (`bool`, *optional*, defaults to `False`): + Skip the weighting. When the parsing is skipped, it is forced True. + """ + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) + else: + prompt_tokens = [ + token[1:-1] + for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True, return_tensors="np").input_ids + ] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [ + token[1:-1] + for token in pipe.tokenizer( + uncond_prompt, + max_length=max_length, + truncation=True, + return_tensors="np", + ).input_ids + ] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = pipe.tokenizer.bos_token_id + eos = pipe.tokenizer.eos_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + prompt_tokens = np.array(prompt_tokens, dtype=np.int32) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + uncond_tokens = np.array(uncond_tokens, dtype=np.int32) + + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( + pipe, + prompt_tokens, + pipe.tokenizer.model_max_length, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype) + if uncond_prompt is not None: + uncond_embeddings = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.tokenizer.model_max_length, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.mean(axis=(-2, -1)) + text_embeddings *= prompt_weights[:, :, None] + text_embeddings *= (previous_mean / text_embeddings.mean(axis=(-2, -1)))[:, None, None] + if uncond_prompt is not None: + previous_mean = uncond_embeddings.mean(axis=(-2, -1)) + uncond_embeddings *= uncond_weights[:, :, None] + uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=(-2, -1)))[:, None, None] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if uncond_prompt is not None: + return text_embeddings, uncond_embeddings + + return text_embeddings + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"]) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + return mask + + +class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing + weighting in prompt. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[np.ndarray, PIL.Image.Image] = None, + mask_image: Union[np.ndarray, PIL.Image.Image] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + strength: float = 0.8, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + init_image (`np.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`np.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + latents (`np.ndarray`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # get prompt text embeddings + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + if generator is None: + generator = np.random + + text_embeddings, uncond_embeddings = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + **kwargs, + ) + + text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0) + if do_classifier_free_guidance: + uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0) + text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None + noise = None + + if init_image is None: + latents_shape = ( + batch_size * num_images_per_prompt, + 4, + height // 8, + width // 8, + ) + + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + elif latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + if isinstance(init_image, PIL.Image.Image): + init_image = preprocess_image(init_image) + # encode the init image into latents and scale the latents + init_image = init_image.astype(latents_dtype) + init_latents = self.vae_encoder(sample=init_image)[0] + init_latents = 0.18215 * init_latents + init_latents = np.concatenate([init_latents] * batch_size * num_images_per_prompt) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = preprocess_mask(mask_image) + mask_image = mask_image.astype(latents_dtype) + mask = np.concatenate([mask_image] * batch_size * num_images_per_prompt) + + # check sizes + if not mask.shape == init_latents.shape: + print(mask.shape, init_latents.shape) + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(latents_dtype) + latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), timesteps + ).numpy() + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:] + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + sample=latent_model_input, + timestep=np.array([t]), + encoder_hidden_states=text_embeddings, + ) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample.numpy() + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise( + torch.from_numpy(init_latents_orig), + torch.from_numpy(noise), + torch.tensor([t]), + ).numpy() + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a problem for using half-precision vae decoder if batchsize>1 + image = [] + for i in range(latents.shape[0]): + image.append(self.vae_decoder(latent_sample=latents[i : i + 1])[0]) + image = np.concatenate(image) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # There will throw an error if use safety_checker directly and batchsize>1 + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def text2img( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function for text-to-image generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + latents (`np.ndarray`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) + + def img2img( + self, + init_image: Union[np.ndarray, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[np.random.RandomState] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function for image-to-image generation. + Args: + init_image (`np.ndarray` or `PIL.Image.Image`): + `Image`, or ndarray representing an image batch, that will be used as the starting point for the + process. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + init_image=init_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) + + def inpaint( + self, + init_image: Union[np.ndarray, PIL.Image.Image], + mask_image: Union[np.ndarray, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[np.random.RandomState] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function for inpaint. + Args: + init_image (`np.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`np.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + init_image=init_image, + mask_image=mask_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) diff --git a/examples/community/multilingual_stable_diffusion.py b/examples/community/multilingual_stable_diffusion.py new file mode 100644 index 000000000000..19974d6df08b --- /dev/null +++ b/examples/community/multilingual_stable_diffusion.py @@ -0,0 +1,436 @@ +import inspect +from typing import Callable, List, Optional, Union + +import torch + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import deprecate, logging +from transformers import ( + CLIPFeatureExtractor, + CLIPTextModel, + CLIPTokenizer, + MBart50TokenizerFast, + MBartForConditionalGeneration, + pipeline, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def detect_language(pipe, prompt, batch_size): + """helper function to detect language(s) of prompt""" + + if batch_size == 1: + preds = pipe(prompt, top_k=1, truncation=True, max_length=128) + return preds[0]["label"] + else: + detected_languages = [] + for p in prompt: + preds = pipe(p, top_k=1, truncation=True, max_length=128) + detected_languages.append(preds[0]["label"]) + + return detected_languages + + +def translate_prompt(prompt, translation_tokenizer, translation_model, device): + """helper function to translate prompt to English""" + + encoded_prompt = translation_tokenizer(prompt, return_tensors="pt").to(device) + generated_tokens = translation_model.generate(**encoded_prompt, max_new_tokens=1000) + en_trans = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + + return en_trans[0] + + +class MultilingualStableDiffusion(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion in different languages. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + detection_pipeline ([`pipeline`]): + Transformers pipeline to detect prompt's language. + translation_model ([`MBartForConditionalGeneration`]): + Model to translate prompt to English, if necessary. Please refer to the + [model card](https://huggingface.co/docs/transformers/model_doc/mbart) for details. + translation_tokenizer ([`MBart50TokenizerFast`]): + Tokenizer of the translation model. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + detection_pipeline: pipeline, + translation_model: MBartForConditionalGeneration, + translation_tokenizer: MBart50TokenizerFast, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + detection_pipeline=detection_pipeline, + translation_model=translation_model, + translation_tokenizer=translation_tokenizer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. Can be in different languages. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # detect language and translate if necessary + prompt_language = detect_language(self.detection_pipeline, prompt, batch_size) + if batch_size == 1 and prompt_language != "en": + prompt = translate_prompt(prompt, self.translation_tokenizer, self.translation_model, self.device) + + if isinstance(prompt, list): + for index in range(batch_size): + if prompt_language[index] != "en": + p = translate_prompt( + prompt[index], self.translation_tokenizer, self.translation_model, self.device + ) + prompt[index] = p + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + # detect language and translate it if necessary + negative_prompt_language = detect_language(self.detection_pipeline, negative_prompt, batch_size) + if negative_prompt_language != "en": + negative_prompt = translate_prompt( + negative_prompt, self.translation_tokenizer, self.translation_model, self.device + ) + if isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + # detect language and translate it if necessary + if isinstance(negative_prompt, list): + negative_prompt_languages = detect_language(self.detection_pipeline, negative_prompt, batch_size) + for index in range(batch_size): + if negative_prompt_languages[index] != "en": + p = translate_prompt( + negative_prompt[index], self.translation_tokenizer, self.translation_model, self.device + ) + negative_prompt[index] = p + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) + latents_dtype = text_embeddings.dtype + if latents is None: + if self.device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( + self.device + ) + else: + latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # Some schedulers like PNDM have timesteps as arrays + # It's more optimized to move all timesteps to correct device beforehand + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( + self.device + ) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) + ) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/sd_text2img_k_diffusion.py b/examples/community/sd_text2img_k_diffusion.py new file mode 100755 index 000000000000..9592f7879ff0 --- /dev/null +++ b/examples/community/sd_text2img_k_diffusion.py @@ -0,0 +1,479 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from typing import Callable, List, Optional, Union + +import torch + +from diffusers import LMSDiscreteScheduler +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.utils import is_accelerate_available, logging +from k_diffusion.external import CompVisDenoiser + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ModelWrapper: + def __init__(self, model, alphas_cumprod): + self.model = model + self.alphas_cumprod = alphas_cumprod + + def apply_model(self, *args, **kwargs): + return self.model(*args, **kwargs).sample + + +class StableDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker, + feature_extractor, + ): + super().__init__() + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + # get correct sigmas from LMS + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + model = ModelWrapper(unet, scheduler.alphas_cumprod) + self.k_diffusion_model = CompVisDenoiser(model) + + def set_sampler(self, scheduler_type: str): + library = importlib.import_module("k_diffusion") + sampling = getattr(library, "sampling") + self.sampler = getattr(sampling, scheduler_type) + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // 8, width // 8) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = True + if guidance_scale <= 1.0: + raise ValueError("has to use guidance_scale") + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device) + sigmas = self.scheduler.sigmas + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + latents = latents * sigmas[0] + self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) + self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device) + + def model_fn(x, t): + latent_model_input = torch.cat([x] * 2) + + noise_pred = self.k_diffusion_model(latent_model_input, t, encoder_hidden_states=text_embeddings) + + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + return noise_pred + + latents = self.sampler(model_fn, latents, sigmas) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/seed_resize_stable_diffusion.py b/examples/community/seed_resize_stable_diffusion.py new file mode 100644 index 000000000000..92cd1c04f9f3 --- /dev/null +++ b/examples/community/seed_resize_stable_diffusion.py @@ -0,0 +1,366 @@ +""" + modified based on diffusion library from Huggingface: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +""" +import inspect +from typing import Callable, List, Optional, Union + +import torch + +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import logging +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class SeedResizeStableDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + text_embeddings: Optional[torch.FloatTensor] = None, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + if text_embeddings is None: + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) + latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.in_channels, 64, 64) + latents_dtype = text_embeddings.dtype + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents_reference = torch.randn( + latents_shape_reference, generator=generator, device="cpu", dtype=latents_dtype + ).to(self.device) + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( + self.device + ) + else: + latents_reference = torch.randn( + latents_shape_reference, generator=generator, device=self.device, dtype=latents_dtype + ) + latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + else: + if latents_reference.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents_reference = latents_reference.to(self.device) + latents = latents.to(self.device) + + # This is the key part of the pipeline where we + # try to ensure that the generated images w/ the same seed + # but different sizes actually result in similar images + dx = (latents_shape[3] - latents_shape_reference[3]) // 2 + dy = (latents_shape[2] - latents_shape_reference[2]) // 2 + w = latents_shape_reference[3] if dx >= 0 else latents_shape_reference[3] + 2 * dx + h = latents_shape_reference[2] if dy >= 0 else latents_shape_reference[2] + 2 * dy + tx = 0 if dx < 0 else dx + ty = 0 if dy < 0 else dy + dx = max(-dx, 0) + dy = max(-dy, 0) + # import pdb + # pdb.set_trace() + latents[:, :, ty : ty + h, tx : tx + w] = latents_reference[:, :, dy : dy + h, dx : dx + w] + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # Some schedulers like PNDM have timesteps as arrays + # It's more optimized to move all timesteps to correct device beforehand + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( + self.device + ) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) + ) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/speech_to_image_diffusion.py b/examples/community/speech_to_image_diffusion.py new file mode 100644 index 000000000000..17bc08e3c291 --- /dev/null +++ b/examples/community/speech_to_image_diffusion.py @@ -0,0 +1,261 @@ +import inspect +from typing import Callable, List, Optional, Union + +import torch + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DiffusionPipeline, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.utils import logging +from transformers import ( + CLIPFeatureExtractor, + CLIPTextModel, + CLIPTokenizer, + WhisperForConditionalGeneration, + WhisperProcessor, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class SpeechToImagePipeline(DiffusionPipeline): + def __init__( + self, + speech_model: WhisperForConditionalGeneration, + speech_processor: WhisperProcessor, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + speech_model=speech_model, + speech_processor=speech_processor, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + if slice_size == "auto": + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + self.enable_attention_slicing(None) + + @torch.no_grad() + def __call__( + self, + audio, + sampling_rate=16_000, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + inputs = self.speech_processor.feature_extractor( + audio, return_tensors="pt", sampling_rate=sampling_rate + ).input_features.to(self.device) + predicted_ids = self.speech_model.generate(inputs, max_length=480_000) + + prompt = self.speech_processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[ + 0 + ] + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) + latents_dtype = text_embeddings.dtype + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( + self.device + ) + else: + latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # Some schedulers like PNDM have timesteps as arrays + # It's more optimized to move all timesteps to correct device beforehand + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return image + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) diff --git a/examples/community/stable_diffusion_mega.py b/examples/community/stable_diffusion_mega.py index 67d39f545544..67112b282b67 100644 --- a/examples/community/stable_diffusion_mega.py +++ b/examples/community/stable_diffusion_mega.py @@ -10,7 +10,7 @@ LMSDiscreteScheduler, PNDMScheduler, StableDiffusionImg2ImgPipeline, - StableDiffusionInpaintPipeline, + StableDiffusionInpaintPipelineLegacy, StableDiffusionPipeline, UNet2DConditionModel, ) @@ -42,11 +42,11 @@ class StableDiffusionMegaPipeline(DiffusionPipeline): [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionMegaSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -136,7 +136,7 @@ def inpaint( callback_steps: Optional[int] = 1, ): # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline - return StableDiffusionInpaintPipeline(**self.components)( + return StableDiffusionInpaintPipelineLegacy(**self.components)( prompt=prompt, init_image=init_image, mask_image=mask_image, diff --git a/examples/community/text_inpainting.py b/examples/community/text_inpainting.py new file mode 100644 index 000000000000..a4368f8b4365 --- /dev/null +++ b/examples/community/text_inpainting.py @@ -0,0 +1,320 @@ +from typing import Callable, List, Optional, Union + +import torch + +import PIL +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import deprecate, is_accelerate_available, logging +from transformers import ( + CLIPFeatureExtractor, + CLIPSegForImageSegmentation, + CLIPSegProcessor, + CLIPTextModel, + CLIPTokenizer, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class TextInpainting(DiffusionPipeline): + r""" + Pipeline for text based inpainting using Stable Diffusion. + Uses CLIPSeg to get a mask from the given text, then calls the Inpainting pipeline with the generated mask + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + segmentation_model ([`CLIPSegForImageSegmentation`]): + CLIPSeg Model to generate mask from the given text. Please refer to the [model card]() for details. + segmentation_processor ([`CLIPSegProcessor`]): + CLIPSeg processor to get image, text features to translate prompt to English, if necessary. Please refer to the + [model card](https://huggingface.co/docs/transformers/model_doc/clipseg) for details. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + segmentation_model: CLIPSegForImageSegmentation, + segmentation_processor: CLIPSegProcessor, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + segmentation_model=segmentation_model, + segmentation_processor=segmentation_processor, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device("cuda") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image], + text: str, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + text (`str``): + The text to use to generate the mask. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # We use the input text to generate the mask + inputs = self.segmentation_processor( + text=[text], images=[image], padding="max_length", return_tensors="pt" + ).to(self.device) + outputs = self.segmentation_model(**inputs) + mask = torch.sigmoid(outputs.logits).cpu().detach().unsqueeze(-1).numpy() + mask_pil = self.numpy_to_pil(mask)[0].resize(image.size) + + # Run inpainting pipeline with the generated mask + inpainting_pipeline = StableDiffusionInpaintPipeline( + vae=self.vae, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + unet=self.unet, + scheduler=self.scheduler, + safety_checker=self.safety_checker, + feature_extractor=self.feature_extractor, + ) + return inpainting_pipeline( + prompt=prompt, + image=image, + mask_image=mask_pil, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) diff --git a/examples/community/wildcard_stable_diffusion.py b/examples/community/wildcard_stable_diffusion.py new file mode 100644 index 000000000000..282be8e48b7d --- /dev/null +++ b/examples/community/wildcard_stable_diffusion.py @@ -0,0 +1,418 @@ +import inspect +import os +import random +import re +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Union + +import torch + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import deprecate, logging +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +global_re_wildcard = re.compile(r"__([^_]*)__") + + +def get_filename(path: str): + # this doesn't work on Windows + return os.path.basename(path).split(".txt")[0] + + +def read_wildcard_values(path: str): + with open(path, encoding="utf8") as f: + return f.read().splitlines() + + +def grab_wildcard_values(wildcard_option_dict: Dict[str, List[str]] = {}, wildcard_files: List[str] = []): + for wildcard_file in wildcard_files: + filename = get_filename(wildcard_file) + read_values = read_wildcard_values(wildcard_file) + if filename not in wildcard_option_dict: + wildcard_option_dict[filename] = [] + wildcard_option_dict[filename].extend(read_values) + return wildcard_option_dict + + +def replace_prompt_with_wildcards( + prompt: str, wildcard_option_dict: Dict[str, List[str]] = {}, wildcard_files: List[str] = [] +): + new_prompt = prompt + + # get wildcard options + wildcard_option_dict = grab_wildcard_values(wildcard_option_dict, wildcard_files) + + for m in global_re_wildcard.finditer(new_prompt): + wildcard_value = m.group() + replace_value = random.choice(wildcard_option_dict[wildcard_value.strip("__")]) + new_prompt = new_prompt.replace(wildcard_value, replace_value, 1) + + return new_prompt + + +@dataclass +class WildcardStableDiffusionOutput(StableDiffusionPipelineOutput): + prompts: List[str] + + +class WildcardStableDiffusionPipeline(DiffusionPipeline): + r""" + Example Usage: + pipe = WildcardStableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="fp16", + torch_dtype=torch.float16, + ) + prompt = "__animal__ sitting on a __object__ wearing a __clothing__" + out = pipe( + prompt, + wildcard_option_dict={ + "clothing":["hat", "shirt", "scarf", "beret"] + }, + wildcard_files=["object.txt", "animal.txt"], + num_prompt_samples=1 + ) + + + Pipeline for text-to-image generation with wild cards using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + wildcard_option_dict: Dict[str, List[str]] = {}, + wildcard_files: List[str] = [], + num_prompt_samples: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + wildcard_option_dict (Dict[str, List[str]]): + dict with key as `wildcard` and values as a list of possible replacements. For example if a prompt, "A __animal__ sitting on a chair". A wildcard_option_dict can provide possible values for "animal" like this: {"animal":["dog", "cat", "fox"]} + wildcard_files: (List[str]) + List of filenames of txt files for wildcard replacements. For example if a prompt, "A __animal__ sitting on a chair". A file can be provided ["animal.txt"] + num_prompt_samples: int + Number of times to sample wildcards for each prompt provided + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + if isinstance(prompt, str): + prompt = [ + replace_prompt_with_wildcards(prompt, wildcard_option_dict, wildcard_files) + for i in range(num_prompt_samples) + ] + batch_size = len(prompt) + elif isinstance(prompt, list): + prompt_list = [] + for p in prompt: + for i in range(num_prompt_samples): + prompt_list.append(replace_prompt_with_wildcards(p, wildcard_option_dict, wildcard_files)) + prompt = prompt_list + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) + latents_dtype = text_embeddings.dtype + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( + self.device + ) + else: + latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # Some schedulers like PNDM have timesteps as arrays + # It's more optimized to move all timesteps to correct device beforehand + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( + self.device + ) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) + ) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return WildcardStableDiffusionOutput(images=image, nsfw_content_detected=has_nsfw_concept, prompts=prompt) diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index 7094570ad5eb..7aaf1bc46c22 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -4,13 +4,12 @@ The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for stable diffusion. -## Running locally +## Running locally with PyTorch ### Installing the dependencies Before running the scripts, make sure to install the library's training dependencies: ```bash -pip install git+https://github.com/huggingface/diffusers.git pip install -U -r requirements.txt ``` @@ -88,11 +87,12 @@ accelerate launch train_dreambooth.py \ --max_train_steps=800 ``` + ### Training on a 16GB GPU: With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes it's possible to run train dreambooth on a 16GB GPU. -Install `bitsandbytes` with `pip install bitsandbytes` +To install `bitandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation). ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" @@ -141,7 +141,7 @@ export INSTANCE_DIR="path-to-instance-images" export CLASS_DIR="path-to-class-images" export OUTPUT_DIR="path-to-save-model" -accelerate launch train_dreambooth.py \ +accelerate launch --mixed_precision="fp16" train_dreambooth.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --instance_data_dir=$INSTANCE_DIR \ --class_data_dir=$CLASS_DIR \ @@ -151,16 +151,49 @@ accelerate launch train_dreambooth.py \ --class_prompt="a photo of dog" \ --resolution=512 \ --train_batch_size=1 \ + --sample_batch_size=1 \ --gradient_accumulation_steps=1 --gradient_checkpointing \ --learning_rate=5e-6 \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 \ - --mixed_precision=fp16 + --max_train_steps=800 +``` + +### Fine-tune text encoder with the UNet. + +The script also allows to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning `text_encoder` gives much better results especially on faces. +Pass the `--train_text_encoder` argument to the script to enable training `text_encoder`. + +___Note: Training text encoder requires more memory, with this option the training won't fit on 16GB GPU. It needs at least 24GB VRAM.___ + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path-to-instance-images" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_text_encoder \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --use_8bit_adam \ + --gradient_checkpointing \ + --learning_rate=2e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --max_train_steps=800 ``` -## Inference +### Inference Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt. @@ -176,3 +209,85 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] image.save("dog-bucket.png") ``` + + +## Running with Flax/JAX + +For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script. + +____Note: The flax example don't yet support features like gradient checkpoint, gradient accumulation etc, so to use flax for faster training we will need >30GB cards.___ + + +Before running the scripts, make sure to install the library's training dependencies: + +```bash +pip install -U -r requirements_flax.txt +``` + + +### Training without prior preservation loss + +```bash +export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" +export INSTANCE_DIR="path-to-instance-images" +export OUTPUT_DIR="path-to-save-model" + +python train_dreambooth_flax.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of sks dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --learning_rate=5e-6 \ + --max_train_steps=400 +``` + + +### Training with prior preservation loss + +```bash +export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" +export INSTANCE_DIR="path-to-instance-images" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" + +python train_dreambooth_flax.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --learning_rate=5e-6 \ + --num_class_images=200 \ + --max_train_steps=800 +``` + + +### Fine-tune text encoder with the UNet. + +```bash +export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" +export INSTANCE_DIR="path-to-instance-images" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" + +python train_dreambooth_flax.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_text_encoder \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --learning_rate=2e-6 \ + --num_class_images=200 \ + --max_train_steps=800 +``` diff --git a/examples/dreambooth/requirements.txt b/examples/dreambooth/requirements.txt index c0649bbe2bef..2174abfabbff 100644 --- a/examples/dreambooth/requirements.txt +++ b/examples/dreambooth/requirements.txt @@ -1,3 +1,4 @@ +diffusers>==0.5.0 accelerate torchvision transformers>=4.21.0 diff --git a/examples/dreambooth/requirements_flax.txt b/examples/dreambooth/requirements_flax.txt new file mode 100644 index 000000000000..a1adc6632f4c --- /dev/null +++ b/examples/dreambooth/requirements_flax.txt @@ -0,0 +1,9 @@ +diffusers>==0.5.1 +transformers>=4.21.0 +flax +optax +torch +torchvision +ftfy +tensorboard +modelcards \ No newline at end of file diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index ca39aeff236b..1f6c730f2baf 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1,4 +1,6 @@ import argparse +import hashlib +import itertools import math import os from pathlib import Path @@ -24,7 +26,7 @@ logger = get_logger(__name__) -def parse_args(): +def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", @@ -33,6 +35,13 @@ def parse_args(): required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) parser.add_argument( "--tokenizer_name", type=str, @@ -57,6 +66,7 @@ def parse_args(): "--instance_prompt", type=str, default=None, + required=True, help="The prompt with identifier specifying the instance", ) parser.add_argument( @@ -100,6 +110,7 @@ def parse_args(): parser.add_argument( "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" ) + parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) @@ -176,29 +187,35 @@ def parse_args(): parser.add_argument( "--mixed_precision", type=str, - default="no", + default=None, choices=["no", "fp16", "bf16"], help=( - "Whether to use mixed precision. Choose" - "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." - "and an Nvidia Ampere GPU." + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") - args = parser.parse_args() + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank - if args.instance_data_dir is None: - raise ValueError("You must specify a train data directory.") - if args.with_prior_preservation: if args.class_data_dir is None: raise ValueError("You must specify a data directory for class images.") if args.class_prompt is None: raise ValueError("You must specify prompt for class images.") + else: + if args.class_data_dir is not None: + logger.warning("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + logger.warning("You need not use --class_prompt without --with_prior_preservation.") return args @@ -309,8 +326,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: return f"{organization}/{model_id}" -def main(): - args = parse_args() +def main(args): logging_dir = Path(args.output_dir, args.logging_dir) accelerator = Accelerator( @@ -320,6 +336,15 @@ def main(): logging_dir=logging_dir, ) + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + if args.seed is not None: set_seed(args.seed) @@ -332,7 +357,10 @@ def main(): if cur_class_images < args.num_class_images: torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 pipeline = StableDiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, torch_dtype=torch_dtype + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, ) pipeline.set_progress_bar_config(disable=True) @@ -351,7 +379,9 @@ def main(): images = pipeline(example["prompt"]).images for i, image in enumerate(images): - image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg") + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) del pipeline if torch.cuda.is_available(): @@ -376,17 +406,42 @@ def main(): # Load the tokenizer if args.tokenizer_name: - tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + tokenizer = CLIPTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + ) elif args.pretrained_model_name_or_path: - tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) # Load models and create wrapper for stable diffusion - text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + ) + + vae.requires_grad_(False) + if not args.train_text_encoder: + text_encoder.requires_grad_(False) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() if args.scale_lr: args.learning_rate = ( @@ -406,17 +461,18 @@ def main(): else: optimizer_class = torch.optim.AdamW + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) optimizer = optimizer_class( - unet.parameters(), # only optimize unet + params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 - ) + noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, @@ -441,7 +497,12 @@ def collate_fn(examples): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids + input_ids = tokenizer.pad( + {"input_ids": input_ids}, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids batch = { "input_ids": input_ids, @@ -450,7 +511,7 @@ def collate_fn(examples): return batch train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 ) # Scheduler and math around the number of training steps. @@ -467,21 +528,27 @@ def collate_fn(examples): num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, ) - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler - ) + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) weight_dtype = torch.float32 - if args.mixed_precision == "fp16": + if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": + elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Move text_encode and vae to gpu. # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. - text_encoder.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -513,12 +580,13 @@ def collate_fn(examples): for epoch in range(args.num_train_epochs): unet.train() + if args.train_text_encoder: + text_encoder.train() for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space - with torch.no_grad(): - latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -532,8 +600,7 @@ def collate_fn(examples): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - with torch.no_grad(): - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample @@ -556,7 +623,12 @@ def collate_fn(examples): accelerator.backward(loss) if accelerator.sync_gradients: - accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() @@ -578,7 +650,10 @@ def collate_fn(examples): # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: pipeline = StableDiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet) + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + revision=args.revision, ) pipeline.save_pretrained(args.output_dir) @@ -589,4 +664,5 @@ def collate_fn(examples): if __name__ == "__main__": - main() + args = parse_args() + main(args) diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py new file mode 100644 index 000000000000..6606af4f17f7 --- /dev/null +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -0,0 +1,652 @@ +import argparse +import hashlib +import logging +import math +import os +from pathlib import Path +from typing import Optional + +import numpy as np +import torch +import torch.utils.checkpoint +from torch.utils.data import Dataset + +import jax +import jax.numpy as jnp +import optax +import transformers +from diffusers import ( + FlaxAutoencoderKL, + FlaxDDPMScheduler, + FlaxPNDMScheduler, + FlaxStableDiffusionPipeline, + FlaxUNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker +from flax import jax_utils +from flax.training import train_state +from flax.training.common_utils import shard +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed + + +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If not have enough images, additional images will be" + " sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" + ) + parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.instance_data_dir is None: + raise ValueError("You must specify a train data directory.") + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + return example + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def get_params_to_save(params): + return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params)) + + +def main(): + args = parse_args() + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # Setup logging, we only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + transformers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + + if args.seed is not None: + set_seed(args.seed) + + rng = jax.random.PRNGKey(args.seed) + + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, safety_checker=None + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + total_sample_batch_size = args.sample_batch_size * jax.local_device_count() + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0 + ): + prompt_ids = pipeline.prepare_inputs(example["prompt"]) + prompt_ids = shard(prompt_ids) + p_params = jax_utils.replicate(params) + rng = jax.random.split(rng)[0] + sample_rng = jax.random.split(rng, jax.device_count()) + images = pipeline(prompt_ids, p_params, sample_rng, jit=True).images + images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) + images = pipeline.numpy_to_pil(np.array(images)) + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + + # Handle the repository creation + if jax.process_index() == 0: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer and add the placeholder token as a additional special token + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + def collate_fn(examples): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if args.with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = tokenizer.pad( + {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" + ).input_ids + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + batch = {k: v.numpy() for k, v in batch.items()} + return batch + + total_train_batch_size = args.train_batch_size * jax.local_device_count() + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=total_train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True + ) + + weight_dtype = jnp.float32 + if args.mixed_precision == "fp16": + weight_dtype = jnp.float16 + elif args.mixed_precision == "bf16": + weight_dtype = jnp.bfloat16 + + # Load models and create wrapper for stable diffusion + text_encoder = FlaxCLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype + ) + vae, vae_params = FlaxAutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype + ) + unet, unet_params = FlaxUNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype + ) + + # Optimization + if args.scale_lr: + args.learning_rate = args.learning_rate * total_train_batch_size + + constant_scheduler = optax.constant_schedule(args.learning_rate) + + adamw = optax.adamw( + learning_rate=constant_scheduler, + b1=args.adam_beta1, + b2=args.adam_beta2, + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + + optimizer = optax.chain( + optax.clip_by_global_norm(args.max_grad_norm), + adamw, + ) + + unet_state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer) + text_encoder_state = train_state.TrainState.create( + apply_fn=text_encoder.__call__, params=text_encoder.params, tx=optimizer + ) + + noise_scheduler = FlaxDDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 + ) + + # Initialize our training + train_rngs = jax.random.split(rng, jax.local_device_count()) + + def train_step(unet_state, text_encoder_state, vae_params, batch, train_rng): + dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3) + + if args.train_text_encoder: + params = {"text_encoder": text_encoder_state.params, "unet": unet_state.params} + else: + params = {"unet": unet_state.params} + + def compute_loss(params): + # Convert images to latent space + vae_outputs = vae.apply( + {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode + ) + latents = vae_outputs.latent_dist.sample(sample_rng) + # (NHWC) -> (NCHW) + latents = jnp.transpose(latents, (0, 3, 1, 2)) + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise_rng, timestep_rng = jax.random.split(sample_rng) + noise = jax.random.normal(noise_rng, latents.shape) + # Sample a random timestep for each image + bsz = latents.shape[0] + timesteps = jax.random.randint( + timestep_rng, + (bsz,), + 0, + noise_scheduler.config.num_train_timesteps, + ) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + if args.train_text_encoder: + encoder_hidden_states = text_encoder_state.apply_fn( + batch["input_ids"], params=params["text_encoder"], dropout_rng=dropout_rng, train=True + )[0] + else: + encoder_hidden_states = text_encoder( + batch["input_ids"], params=text_encoder_state.params, train=False + )[0] + + # Predict the noise residual + unet_outputs = unet.apply( + {"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True + ) + noise_pred = unet_outputs.sample + + if args.with_prior_preservation: + # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. + noise_pred, noise_pred_prior = jnp.split(noise_pred, 2, axis=0) + noise, noise_prior = jnp.split(noise, 2, axis=0) + + # Compute instance loss + loss = (noise - noise_pred) ** 2 + loss = loss.mean() + + # Compute prior loss + prior_loss = (noise_prior - noise_pred_prior) ** 2 + prior_loss = prior_loss.mean() + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = (noise - noise_pred) ** 2 + loss = loss.mean() + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + loss, grad = grad_fn(params) + grad = jax.lax.pmean(grad, "batch") + + new_unet_state = unet_state.apply_gradients(grads=grad["unet"]) + if args.train_text_encoder: + new_text_encoder_state = text_encoder_state.apply_gradients(grads=grad["text_encoder"]) + else: + new_text_encoder_state = text_encoder_state + + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + + return new_unet_state, new_text_encoder_state, metrics, new_train_rng + + # Create parallel version of the train step + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1)) + + # Replicate the train state on each device + unet_state = jax_utils.replicate(unet_state) + text_encoder_state = jax_utils.replicate(text_encoder_state) + vae_params = jax_utils.replicate(vae_params) + + # Train! + num_update_steps_per_epoch = math.ceil(len(train_dataloader)) + + # Scheduler and math around the number of training steps. + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + global_step = 0 + + epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0) + for epoch in epochs: + # ======================== Training ================================ + + train_metrics = [] + + steps_per_epoch = len(train_dataset) // total_train_batch_size + train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False) + # train + for batch in train_dataloader: + batch = shard(batch) + unet_state, text_encoder_state, train_metric, train_rngs = p_train_step( + unet_state, text_encoder_state, vae_params, batch, train_rngs + ) + train_metrics.append(train_metric) + + train_step_progress_bar.update(1) + + global_step += 1 + if global_step >= args.max_train_steps: + break + + train_metric = jax_utils.unreplicate(train_metric) + + train_step_progress_bar.close() + epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") + + # Create the pipeline using using the trained modules and save it. + if jax.process_index() == 0: + scheduler = FlaxPNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ) + safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", from_pt=True + ) + pipeline = FlaxStableDiffusionPipeline( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + ) + + pipeline.save_pretrained( + args.output_dir, + params={ + "text_encoder": get_params_to_save(text_encoder_state.params), + "vae": get_params_to_save(vae_params), + "unet": get_params_to_save(unet_state.params), + "safety_checker": safety_checker.params, + }, + ) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + +if __name__ == "__main__": + main() diff --git a/examples/rl/README.md b/examples/rl/README.md new file mode 100644 index 000000000000..d68f2bf780e3 --- /dev/null +++ b/examples/rl/README.md @@ -0,0 +1,19 @@ +# Overview + +These examples show how to run (Diffuser)[https://arxiv.org/abs/2205.09991] in Diffusers. +There are four scripts, +1. `run_diffuser_locomotion.py` to sample actions and run them in the environment, +2. and `run_diffuser_gen_trajectories.py` to just sample actions from the pre-trained diffusion model. + +You will need some RL specific requirements to run the examples: + +``` +pip install -f https://download.pytorch.org/whl/torch_stable.html \ + free-mujoco-py \ + einops \ + gym==0.24.1 \ + protobuf==3.20.1 \ + git+https://github.com/rail-berkeley/d4rl.git \ + mediapy \ + Pillow==9.0.0 +``` diff --git a/examples/rl/run_diffuser_gen_trajectories.py b/examples/rl/run_diffuser_gen_trajectories.py new file mode 100644 index 000000000000..5bb068cc9fc7 --- /dev/null +++ b/examples/rl/run_diffuser_gen_trajectories.py @@ -0,0 +1,57 @@ +import d4rl # noqa +import gym +import tqdm +from diffusers.experimental import ValueGuidedRLPipeline + + +config = dict( + n_samples=64, + horizon=32, + num_inference_steps=20, + n_guide_steps=0, + scale_grad_by_std=True, + scale=0.1, + eta=0.0, + t_grad_cutoff=2, + device="cpu", +) + + +if __name__ == "__main__": + env_name = "hopper-medium-v2" + env = gym.make(env_name) + + pipeline = ValueGuidedRLPipeline.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", + env=env, + ) + + env.seed(0) + obs = env.reset() + total_reward = 0 + total_score = 0 + T = 1000 + rollout = [obs.copy()] + try: + for t in tqdm.tqdm(range(T)): + # Call the policy + denorm_actions = pipeline(obs, planning_horizon=32) + + # execute action in environment + next_observation, reward, terminal, _ = env.step(denorm_actions) + score = env.get_normalized_score(total_reward) + # update return + total_reward += reward + total_score += score + print( + f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" + f" {total_score}" + ) + # save observations for rendering + rollout.append(next_observation.copy()) + + obs = next_observation + except KeyboardInterrupt: + pass + + print(f"Total reward: {total_reward}") diff --git a/examples/rl/run_diffuser_locomotion.py b/examples/rl/run_diffuser_locomotion.py new file mode 100644 index 000000000000..e89181610b33 --- /dev/null +++ b/examples/rl/run_diffuser_locomotion.py @@ -0,0 +1,57 @@ +import d4rl # noqa +import gym +import tqdm +from diffusers.experimental import ValueGuidedRLPipeline + + +config = dict( + n_samples=64, + horizon=32, + num_inference_steps=20, + n_guide_steps=2, + scale_grad_by_std=True, + scale=0.1, + eta=0.0, + t_grad_cutoff=2, + device="cpu", +) + + +if __name__ == "__main__": + env_name = "hopper-medium-v2" + env = gym.make(env_name) + + pipeline = ValueGuidedRLPipeline.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", + env=env, + ) + + env.seed(0) + obs = env.reset() + total_reward = 0 + total_score = 0 + T = 1000 + rollout = [obs.copy()] + try: + for t in tqdm.tqdm(range(T)): + # call the policy + denorm_actions = pipeline(obs, planning_horizon=32) + + # execute action in environment + next_observation, reward, terminal, _ = env.step(denorm_actions) + score = env.get_normalized_score(total_reward) + # update return + total_reward += reward + total_score += score + print( + f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" + f" {total_score}" + ) + # save observations for rendering + rollout.append(next_observation.copy()) + + obs = next_observation + except KeyboardInterrupt: + pass + + print(f"Total reward: {total_reward}") diff --git a/examples/test_examples.py b/examples/test_examples.py index 15c8e05a5c12..eb86b18b75f0 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -101,7 +101,7 @@ def test_textual_inversion(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" examples/textual_inversion/textual_inversion.py - --pretrained_model_name_or_path CompVis/stable-diffusion-v1-4 + --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 --train_data_dir docs/source/imgs --learnable_property object --placeholder_token diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index 6aca642cda4a..abe21875848a 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -4,10 +4,10 @@ The `train_text_to_image.py` script shows how to fine-tune stable diffusion mode ___Note___: -___This script is experimental. The script fine-tunes the whole model and often times the model overifits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparamters to get the best result on your dataset.___ +___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparamters to get the best result on your dataset.___ -## Running locally +## Running locally with PyTorch ### Installing the dependencies Before running the scripts, make sure to install the library's training dependencies: @@ -46,7 +46,7 @@ With `gradient_checkpointing` and `mixed_precision` it should be possible to fin export MODEL_NAME="CompVis/stable-diffusion-v1-4" export dataset_name="lambdalabs/pokemon-blip-captions" -accelerate launch train_text_to_image.py \ +accelerate launch --mixed_precision="fp16" train_text_to_image.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --dataset_name=$dataset_name \ --use_ema \ @@ -54,7 +54,6 @@ accelerate launch train_text_to_image.py \ --train_batch_size=1 \ --gradient_accumulation_steps=4 \ --gradient_checkpointing \ - --mixed_precision="fp16" \ --max_train_steps=15000 \ --learning_rate=1e-05 \ --max_grad_norm=1 \ @@ -70,7 +69,7 @@ If you wish to use custom loading logic, you should modify the script, we have l export MODEL_NAME="CompVis/stable-diffusion-v1-4" export TRAIN_DIR="path_to_your_dataset" -accelerate launch train_text_to_image.py \ +accelerate launch --mixed_precision="fp16" train_text_to_image.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$TRAIN_DIR \ --use_ema \ @@ -78,7 +77,6 @@ accelerate launch train_text_to_image.py \ --train_batch_size=1 \ --gradient_accumulation_steps=4 \ --gradient_checkpointing \ - --mixed_precision="fp16" \ --max_train_steps=15000 \ --learning_rate=1e-05 \ --max_grad_norm=1 \ @@ -86,6 +84,7 @@ accelerate launch train_text_to_image.py \ --output_dir="sd-pokemon-model" ``` + Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline` @@ -99,3 +98,54 @@ pipe.to("cuda") image = pipe(prompt="yoda").images[0] image.save("yoda-pokemon.png") ``` + + + +## Training with Flax/JAX + +For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script. + +____Note: The flax example don't yet support features like gradient checkpoint, gradient accumulation etc, so to use flax for faster training we will need >30GB cards.___ + + +Before running the scripts, make sure to install the library's training dependencies: + +```bash +pip install -U -r requirements_flax.txt +``` + +```bash +export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" +export dataset_name="lambdalabs/pokemon-blip-captions" + +python train_text_to_image_flax.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$dataset_name \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --mixed_precision="fp16" \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --output_dir="sd-pokemon-model" +``` + + +To run on your own training files prepare the dataset according to the format required by `datasets`, you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata). +If you wish to use custom loading logic, you should modify the script, we have left pointers for that in the training script. + +```bash +export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" +export TRAIN_DIR="path_to_your_dataset" + +python train_text_to_image_flax.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$TRAIN_DIR \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --mixed_precision="fp16" \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --output_dir="sd-pokemon-model" +``` diff --git a/examples/text_to_image/requirements_flax.txt b/examples/text_to_image/requirements_flax.txt new file mode 100644 index 000000000000..a1adc6632f4c --- /dev/null +++ b/examples/text_to_image/requirements_flax.txt @@ -0,0 +1,9 @@ +diffusers>==0.5.1 +transformers>=4.21.0 +flax +optax +torch +torchvision +ftfy +tensorboard +modelcards \ No newline at end of file diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index c3c3c75438d0..88da2a55096a 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -186,12 +186,12 @@ def parse_args(): parser.add_argument( "--mixed_precision", type=str, - default="no", + default=None, choices=["no", "fp16", "bf16"], help=( - "Whether to use mixed precision. Choose" - "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." - "and an Nvidia Ampere GPU." + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) parser.add_argument( @@ -372,11 +372,7 @@ def main(): weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) - - # TODO (patil-suraj): load scheduler using args - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" - ) + noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") # Get the datasets: you can either provide your own training and evaluation files (see below) # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). @@ -500,9 +496,9 @@ def collate_fn(examples): ) weight_dtype = torch.float32 - if args.mixed_precision == "fp16": + if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": + elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Move text_encode and vae to gpu. @@ -609,9 +605,7 @@ def collate_fn(examples): vae=vae, unet=unet, tokenizer=tokenizer, - scheduler=PNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True - ), + scheduler=PNDMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler"), safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), ) diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py new file mode 100644 index 000000000000..89a8dec7289e --- /dev/null +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -0,0 +1,560 @@ +import argparse +import logging +import math +import os +import random +from pathlib import Path +from typing import Optional + +import numpy as np +import torch +import torch.utils.checkpoint + +import jax +import jax.numpy as jnp +import optax +import transformers +from datasets import load_dataset +from diffusers import ( + FlaxAutoencoderKL, + FlaxDDPMScheduler, + FlaxPNDMScheduler, + FlaxStableDiffusionPipeline, + FlaxUNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker +from flax import jax_utils +from flax.training import train_state +from flax.training.common_utils import shard +from huggingface_hub import HfFolder, Repository, whoami +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed + + +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + action="store_true", + help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)", + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + return args + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +dataset_name_mapping = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def get_params_to_save(params): + return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params)) + + +def main(): + args = parse_args() + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # Setup logging, we only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + transformers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if jax.process_index() == 0: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + dataset_columns = dataset_name_mapping.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True) + input_ids = inputs.input_ids + return input_ids + + train_transforms = transforms.Compose( + [ + transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [train_transforms(image) for image in images] + examples["input_ids"] = tokenize_captions(examples) + + return examples + + if jax.process_index() == 0: + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + input_ids = [example["input_ids"] for example in examples] + + padded_tokens = tokenizer.pad( + {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" + ) + batch = { + "pixel_values": pixel_values, + "input_ids": padded_tokens.input_ids, + } + batch = {k: v.numpy() for k, v in batch.items()} + + return batch + + total_train_batch_size = args.train_batch_size * jax.local_device_count() + train_dataloader = torch.utils.data.DataLoader( + train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=total_train_batch_size, drop_last=True + ) + + weight_dtype = jnp.float32 + if args.mixed_precision == "fp16": + weight_dtype = jnp.float16 + elif args.mixed_precision == "bf16": + weight_dtype = jnp.bfloat16 + + # Load models and create wrapper for stable diffusion + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + text_encoder = FlaxCLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype + ) + vae, vae_params = FlaxAutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype + ) + unet, unet_params = FlaxUNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype + ) + + # Optimization + if args.scale_lr: + args.learning_rate = args.learning_rate * total_train_batch_size + + constant_scheduler = optax.constant_schedule(args.learning_rate) + + adamw = optax.adamw( + learning_rate=constant_scheduler, + b1=args.adam_beta1, + b2=args.adam_beta2, + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + + optimizer = optax.chain( + optax.clip_by_global_norm(args.max_grad_norm), + adamw, + ) + + state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer) + + noise_scheduler = FlaxDDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 + ) + + # Initialize our training + rng = jax.random.PRNGKey(args.seed) + train_rngs = jax.random.split(rng, jax.local_device_count()) + + def train_step(state, text_encoder_params, vae_params, batch, train_rng): + dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3) + + def compute_loss(params): + # Convert images to latent space + vae_outputs = vae.apply( + {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode + ) + latents = vae_outputs.latent_dist.sample(sample_rng) + # (NHWC) -> (NCHW) + latents = jnp.transpose(latents, (0, 3, 1, 2)) + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise_rng, timestep_rng = jax.random.split(sample_rng) + noise = jax.random.normal(noise_rng, latents.shape) + # Sample a random timestep for each image + bsz = latents.shape[0] + timesteps = jax.random.randint( + timestep_rng, + (bsz,), + 0, + noise_scheduler.config.num_train_timesteps, + ) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder( + batch["input_ids"], + params=text_encoder_params, + train=False, + )[0] + + # Predict the noise residual and compute loss + unet_outputs = unet.apply({"params": params}, noisy_latents, timesteps, encoder_hidden_states, train=True) + noise_pred = unet_outputs.sample + loss = (noise - noise_pred) ** 2 + loss = loss.mean() + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + loss, grad = grad_fn(state.params) + grad = jax.lax.pmean(grad, "batch") + + new_state = state.apply_gradients(grads=grad) + + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + + return new_state, metrics, new_train_rng + + # Create parallel version of the train step + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + + # Replicate the train state on each device + state = jax_utils.replicate(state) + text_encoder_params = jax_utils.replicate(text_encoder.params) + vae_params = jax_utils.replicate(vae_params) + + # Train! + num_update_steps_per_epoch = math.ceil(len(train_dataloader)) + + # Scheduler and math around the number of training steps. + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + global_step = 0 + + epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0) + for epoch in epochs: + # ======================== Training ================================ + + train_metrics = [] + + steps_per_epoch = len(train_dataset) // total_train_batch_size + train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False) + # train + for batch in train_dataloader: + batch = shard(batch) + state, train_metric, train_rngs = p_train_step(state, text_encoder_params, vae_params, batch, train_rngs) + train_metrics.append(train_metric) + + train_step_progress_bar.update(1) + + global_step += 1 + if global_step >= args.max_train_steps: + break + + train_metric = jax_utils.unreplicate(train_metric) + + train_step_progress_bar.close() + epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") + + # Create the pipeline using using the trained modules and save it. + if jax.process_index() == 0: + scheduler = FlaxPNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ) + safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", from_pt=True + ) + pipeline = FlaxStableDiffusionPipeline( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + ) + + pipeline.save_pretrained( + args.output_dir, + params={ + "text_encoder": get_params_to_save(text_encoder_params), + "vae": get_params_to_save(vae_params), + "unet": get_params_to_save(state.params), + "safety_checker": safety_checker.params, + }, + ) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + +if __name__ == "__main__": + main() diff --git a/examples/textual_inversion/README.md b/examples/textual_inversion/README.md index 05d8ffb8c9f2..2edf34cb490d 100644 --- a/examples/textual_inversion/README.md +++ b/examples/textual_inversion/README.md @@ -11,7 +11,7 @@ Colab for training Colab for inference [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb) -## Running locally +## Running locally with PyTorch ### Installing the dependencies Before running the scripts, make sure to install the library's training dependencies: @@ -29,7 +29,7 @@ accelerate config ### Cat toy example -You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. +You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). @@ -48,7 +48,7 @@ Now let's get our dataset.Download 3-4 images from [here](https://drive.google.c And launch the training using ```bash -export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export MODEL_NAME="runwayml/stable-diffusion-v1-5" export DATA_DIR="path-to-dir-containing-images" accelerate launch textual_inversion.py \ @@ -68,7 +68,6 @@ accelerate launch textual_inversion.py \ A full training run takes ~1 hour on one V100 GPU. - ### Inference Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt. @@ -85,3 +84,31 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] image.save("cat-backpack.png") ``` + + +## Training with Flax/JAX + +For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script. + +Before running the scripts, make sure to install the library's training dependencies: + +```bash +pip install -U -r requirements_flax.txt +``` + +```bash +export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" +export DATA_DIR="path-to-dir-containing-images" + +python textual_inversion_flax.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATA_DIR \ + --learnable_property="object" \ + --placeholder_token="" --initializer_token="toy" \ + --resolution=512 \ + --train_batch_size=1 \ + --max_train_steps=3000 \ + --learning_rate=5.0e-04 --scale_lr \ + --output_dir="textual_inversion_cat" +``` +It should be at least 70% faster than the PyTorch script with the same configuration. diff --git a/examples/textual_inversion/requirements_flax.txt b/examples/textual_inversion/requirements_flax.txt new file mode 100644 index 000000000000..a1adc6632f4c --- /dev/null +++ b/examples/textual_inversion/requirements_flax.txt @@ -0,0 +1,9 @@ +diffusers>==0.5.1 +transformers>=4.21.0 +flax +optax +torch +torchvision +ftfy +tensorboard +modelcards \ No newline at end of file diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 18469af3d4ae..7d9fb7c0f175 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -20,12 +20,34 @@ from diffusers.optimization import get_scheduler from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from huggingface_hub import HfFolder, Repository, whoami + +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version from PIL import Image from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + + logger = get_logger(__name__) @@ -260,10 +282,10 @@ def __init__( self._length = self.num_images * repeats self.interpolation = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, + "linear": PIL_INTERPOLATION["linear"], + "bilinear": PIL_INTERPOLATION["bilinear"], + "bicubic": PIL_INTERPOLATION["bicubic"], + "lanczos": PIL_INTERPOLATION["lanczos"], }[interpolation] self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small @@ -419,13 +441,7 @@ def main(): eps=args.adam_epsilon, ) - # TODO (patil-suraj): load scheduler using args - noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - ) + noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") train_dataset = TextualInversionDataset( data_root=args.train_data_dir, @@ -558,9 +574,7 @@ def main(): vae=vae, unet=unet, tokenizer=tokenizer, - scheduler=PNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True - ), + scheduler=PNDMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler"), safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), ) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py new file mode 100644 index 000000000000..6406be8ad698 --- /dev/null +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -0,0 +1,650 @@ +import argparse +import logging +import math +import os +import random +from pathlib import Path +from typing import Optional + +import numpy as np +import torch +import torch.utils.checkpoint +from torch.utils.data import Dataset + +import jax +import jax.numpy as jnp +import optax +import PIL +import transformers +from diffusers import ( + FlaxAutoencoderKL, + FlaxDDPMScheduler, + FlaxPNDMScheduler, + FlaxStableDiffusionPipeline, + FlaxUNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker +from flax import jax_utils +from flax.training import train_state +from flax.training.common_utils import shard +from huggingface_hub import HfFolder, Repository, whoami + +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed + + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data." + ) + parser.add_argument( + "--placeholder_token", + type=str, + default=None, + required=True, + help="A token to use as a placeholder for the concept.", + ) + parser.add_argument( + "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word." + ) + parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'") + parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=5000, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=True, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--use_auth_token", + action="store_true", + help=( + "Will use the token generated when running `huggingface-cli login` (necessary to use this script with" + " private models)." + ), + ) + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.train_data_dir is None: + raise ValueError("You must specify a train data directory.") + + return args + + +imagenet_templates_small = [ + "a photo of a {}", + "a rendering of a {}", + "a cropped photo of the {}", + "the photo of a {}", + "a photo of a clean {}", + "a photo of a dirty {}", + "a dark photo of the {}", + "a photo of my {}", + "a photo of the cool {}", + "a close-up photo of a {}", + "a bright photo of the {}", + "a cropped photo of a {}", + "a photo of the {}", + "a good photo of the {}", + "a photo of one {}", + "a close-up photo of the {}", + "a rendition of the {}", + "a photo of the clean {}", + "a rendition of a {}", + "a photo of a nice {}", + "a good photo of a {}", + "a photo of the nice {}", + "a photo of the small {}", + "a photo of the weird {}", + "a photo of the large {}", + "a photo of a cool {}", + "a photo of a small {}", +] + +imagenet_style_templates_small = [ + "a painting in the style of {}", + "a rendering in the style of {}", + "a cropped painting in the style of {}", + "the painting in the style of {}", + "a clean painting in the style of {}", + "a dirty painting in the style of {}", + "a dark painting in the style of {}", + "a picture in the style of {}", + "a cool painting in the style of {}", + "a close-up painting in the style of {}", + "a bright painting in the style of {}", + "a cropped painting in the style of {}", + "a good painting in the style of {}", + "a close-up painting in the style of {}", + "a rendition in the style of {}", + "a nice painting in the style of {}", + "a small painting in the style of {}", + "a weird painting in the style of {}", + "a large painting in the style of {}", +] + + +class TextualInversionDataset(Dataset): + def __init__( + self, + data_root, + tokenizer, + learnable_property="object", # [object, style] + size=512, + repeats=100, + interpolation="bicubic", + flip_p=0.5, + set="train", + placeholder_token="*", + center_crop=False, + ): + self.data_root = data_root + self.tokenizer = tokenizer + self.learnable_property = learnable_property + self.size = size + self.placeholder_token = placeholder_token + self.center_crop = center_crop + self.flip_p = flip_p + + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + + self.num_images = len(self.image_paths) + self._length = self.num_images + + if set == "train": + self._length = self.num_images * repeats + + self.interpolation = { + "linear": PIL_INTERPOLATION["linear"], + "bilinear": PIL_INTERPOLATION["bilinear"], + "bicubic": PIL_INTERPOLATION["bicubic"], + "lanczos": PIL_INTERPOLATION["lanczos"], + }[interpolation] + + self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small + self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = {} + image = Image.open(self.image_paths[i % self.num_images]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + placeholder_string = self.placeholder_token + text = random.choice(self.templates).format(placeholder_string) + + example["input_ids"] = self.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids[0] + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + h, w, = ( + img.shape[0], + img.shape[1], + ) + img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] + + image = Image.fromarray(img) + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip_transform(image) + image = np.array(image).astype(np.uint8) + image = (image / 127.5 - 1.0).astype(np.float32) + + example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng): + if model.config.vocab_size == new_num_tokens or new_num_tokens is None: + return + model.config.vocab_size = new_num_tokens + + params = model.params + old_embeddings = params["text_model"]["embeddings"]["token_embedding"]["embedding"] + old_num_tokens, emb_dim = old_embeddings.shape + + initializer = jax.nn.initializers.normal() + + new_embeddings = initializer(rng, (new_num_tokens, emb_dim)) + new_embeddings = new_embeddings.at[:old_num_tokens].set(old_embeddings) + new_embeddings = new_embeddings.at[placeholder_token_id].set(new_embeddings[initializer_token_id]) + params["text_model"]["embeddings"]["token_embedding"]["embedding"] = new_embeddings + + model.params = params + return model + + +def get_params_to_save(params): + return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params)) + + +def main(): + args = parse_args() + + if args.seed is not None: + set_seed(args.seed) + + if jax.process_index() == 0: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # Setup logging, we only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + transformers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + + # Load the tokenizer and add the placeholder token as a additional special token + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + + # Add the placeholder token in tokenizer + num_added_tokens = tokenizer.add_tokens(args.placeholder_token) + if num_added_tokens == 0: + raise ValueError( + f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." + ) + + # Convert the initializer_token, placeholder_token to ids + token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) + # Check if initializer_token is a single token or a sequence of tokens + if len(token_ids) > 1: + raise ValueError("The initializer token must be a single token.") + + initializer_token_id = token_ids[0] + placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) + + # Load models and create wrapper for stable diffusion + text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") + vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") + unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + + # Create sampling rng + rng = jax.random.PRNGKey(args.seed) + rng, _ = jax.random.split(rng) + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder = resize_token_embeddings( + text_encoder, len(tokenizer), initializer_token_id, placeholder_token_id, rng + ) + original_token_embeds = text_encoder.params["text_model"]["embeddings"]["token_embedding"]["embedding"] + + train_dataset = TextualInversionDataset( + data_root=args.train_data_dir, + tokenizer=tokenizer, + size=args.resolution, + placeholder_token=args.placeholder_token, + repeats=args.repeats, + learnable_property=args.learnable_property, + center_crop=args.center_crop, + set="train", + ) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + input_ids = torch.stack([example["input_ids"] for example in examples]) + + batch = {"pixel_values": pixel_values, "input_ids": input_ids} + batch = {k: v.numpy() for k, v in batch.items()} + + return batch + + total_train_batch_size = args.train_batch_size * jax.local_device_count() + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=total_train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn + ) + + # Optimization + if args.scale_lr: + args.learning_rate = args.learning_rate * total_train_batch_size + + constant_scheduler = optax.constant_schedule(args.learning_rate) + + optimizer = optax.adamw( + learning_rate=constant_scheduler, + b1=args.adam_beta1, + b2=args.adam_beta2, + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + + def create_mask(params, label_fn): + def _map(params, mask, label_fn): + for k in params: + if label_fn(k): + mask[k] = "token_embedding" + else: + if isinstance(params[k], dict): + mask[k] = {} + _map(params[k], mask[k], label_fn) + else: + mask[k] = "zero" + + mask = {} + _map(params, mask, label_fn) + return mask + + def zero_grads(): + # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491 + def init_fn(_): + return () + + def update_fn(updates, state, params=None): + return jax.tree_util.tree_map(jnp.zeros_like, updates), () + + return optax.GradientTransformation(init_fn, update_fn) + + # Zero out gradients of layers other than the token embedding layer + tx = optax.multi_transform( + {"token_embedding": optimizer, "zero": zero_grads()}, + create_mask(text_encoder.params, lambda s: s == "token_embedding"), + ) + + state = train_state.TrainState.create(apply_fn=text_encoder.__call__, params=text_encoder.params, tx=tx) + + noise_scheduler = FlaxDDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 + ) + + # Initialize our training + train_rngs = jax.random.split(rng, jax.local_device_count()) + + # Define gradient train step fn + def train_step(state, vae_params, unet_params, batch, train_rng): + dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3) + + def compute_loss(params): + vae_outputs = vae.apply( + {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode + ) + latents = vae_outputs.latent_dist.sample(sample_rng) + # (NHWC) -> (NCHW) + latents = jnp.transpose(latents, (0, 3, 1, 2)) + latents = latents * 0.18215 + + noise_rng, timestep_rng = jax.random.split(sample_rng) + noise = jax.random.normal(noise_rng, latents.shape) + bsz = latents.shape[0] + timesteps = jax.random.randint( + timestep_rng, + (bsz,), + 0, + noise_scheduler.config.num_train_timesteps, + ) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + encoder_hidden_states = state.apply_fn( + batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True + )[0] + unet_outputs = unet.apply( + {"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False + ) + noise_pred = unet_outputs.sample + loss = (noise - noise_pred) ** 2 + loss = loss.mean() + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + loss, grad = grad_fn(state.params) + grad = jax.lax.pmean(grad, "batch") + new_state = state.apply_gradients(grads=grad) + + # Keep the token embeddings fixed except the newly added embeddings for the concept, + # as we only want to optimize the concept embeddings + token_embeds = original_token_embeds.at[placeholder_token_id].set( + new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"][placeholder_token_id] + ) + new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"] = token_embeds + + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + return new_state, metrics, new_train_rng + + # Create parallel version of the train and eval step + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + + # Replicate the train state on each device + state = jax_utils.replicate(state) + vae_params = jax_utils.replicate(vae_params) + unet_params = jax_utils.replicate(unet_params) + + # Train! + num_update_steps_per_epoch = math.ceil(len(train_dataloader)) + + # Scheduler and math around the number of training steps. + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + global_step = 0 + + epochs = tqdm(range(args.num_train_epochs), desc=f"Epoch ... (1/{args.num_train_epochs})", position=0) + for epoch in epochs: + # ======================== Training ================================ + + train_metrics = [] + + steps_per_epoch = len(train_dataset) // total_train_batch_size + train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False) + # train + for batch in train_dataloader: + batch = shard(batch) + state, train_metric, train_rngs = p_train_step(state, vae_params, unet_params, batch, train_rngs) + train_metrics.append(train_metric) + + train_step_progress_bar.update(1) + global_step += 1 + + if global_step >= args.max_train_steps: + break + + train_metric = jax_utils.unreplicate(train_metric) + + train_step_progress_bar.close() + epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") + + # Create the pipeline using using the trained modules and save it. + if jax.process_index() == 0: + scheduler = FlaxPNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ) + safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", from_pt=True + ) + pipeline = FlaxStableDiffusionPipeline( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + ) + + pipeline.save_pretrained( + args.output_dir, + params={ + "text_encoder": get_params_to_save(state.params), + "vae": get_params_to_save(vae_params), + "unet": get_params_to_save(unet_params), + "safety_checker": safety_checker.params, + }, + ) + + # Also save the newly trained embeddings + learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][ + placeholder_token_id + ] + learned_embeds_dict = {args.placeholder_token: learned_embeds} + jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + +if __name__ == "__main__": + main() diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md index a2a5a840af68..dbb849178982 100644 --- a/examples/unconditional_image_generation/README.md +++ b/examples/unconditional_image_generation/README.md @@ -7,7 +7,7 @@ Creating a training image set is [described in a different document](https://hug Before running the scripts, make sure to install the library's training dependencies: ```bash -pip install diffusers[training] accelerate datasets +pip install diffusers[training] accelerate datasets tensorboard ``` And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: @@ -127,3 +127,24 @@ dataset.push_to_hub("name_of_your_dataset", private=True) and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub. More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets). + +#### Use ONNXRuntime to accelerate training + +In order to leverage onnxruntime to accelerate training, please use train_unconditional_ort.py + +The command to train a DDPM UNet model on the Oxford Flowers dataset with onnxruntime: + +```bash +accelerate launch train_unconditional_ort.py \ + --dataset_name="huggan/flowers-102-categories" \ + --resolution=64 \ + --output_dir="ddpm-ema-flowers-64" \ + --train_batch_size=16 \ + --num_epochs=1 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-4 \ + --lr_warmup_steps=500 \ + --mixed_precision=fp16 + ``` + +Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions. \ No newline at end of file diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index f9f8c85bd1e3..6abe46c57def 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -1,6 +1,9 @@ import argparse +import inspect import math import os +from pathlib import Path +from typing import Optional import torch import torch.nn.functional as F @@ -8,10 +11,12 @@ from accelerate import Accelerator from accelerate.logging import get_logger from datasets import load_dataset -from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel -from diffusers.hub_utils import init_git_repo +from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel, __version__ from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel +from diffusers.utils import deprecate +from huggingface_hub import HfFolder, Repository, whoami +from packaging import version from torchvision.transforms import ( CenterCrop, Compose, @@ -25,6 +30,199 @@ logger = get_logger(__name__) +diffusers_version = version.parse(version.parse(__version__).base_version) + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + if not isinstance(arr, torch.Tensor): + arr = torch.from_numpy(arr) + res = arr[timesteps].float().to(timesteps.device) + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that HF Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="ddpm-model-64", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--overwrite_output_dir", action="store_true") + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument( + "--resolution", + type=int, + default=64, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" + " process." + ), + ) + parser.add_argument("--num_epochs", type=int, default=100) + parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.") + parser.add_argument( + "--save_model_epochs", type=int, default=10, help="How often to save the model during training." + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="cosine", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer." + ) + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.") + parser.add_argument( + "--use_ema", + action="store_true", + default=True, + help="Whether to use Exponential Moving Average for the final model weights.", + ) + parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.") + parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.") + parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--hub_private_repo", action="store_true", help="Whether or not to create a private repository." + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + + parser.add_argument( + "--prediction_type", + type=str, + default="epsilon", + choices=["epsilon", "sample"], + help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.", + ) + + parser.add_argument("--ddpm_num_steps", type=int, default=1000) + parser.add_argument("--ddpm_beta_schedule", type=str, default="linear") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("You must specify either a dataset name from the hub or a train data directory.") + + return args + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" def main(args): @@ -59,7 +257,17 @@ def main(args): "UpBlock2D", ), ) - noise_scheduler = DDPMScheduler(num_train_timesteps=1000) + accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) + + if accepts_prediction_type: + noise_scheduler = DDPMScheduler( + num_train_timesteps=args.ddpm_num_steps, + beta_schedule=args.ddpm_beta_schedule, + prediction_type=args.prediction_type, + ) + else: + noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule) + optimizer = torch.optim.AdamW( model.parameters(), lr=args.learning_rate, @@ -92,8 +300,12 @@ def transforms(examples): images = [augmentations(image.convert("RGB")) for image in examples["image"]] return {"input": images} + logger.info(f"Dataset size: {len(dataset)}") + dataset.set_transform(transforms) - train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True) + train_dataloader = torch.utils.data.DataLoader( + dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers + ) lr_scheduler = get_scheduler( args.lr_scheduler, @@ -110,8 +322,22 @@ def transforms(examples): ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay) - if args.push_to_hub: - repo = init_git_repo(args, at_init=True) + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) if accelerator.is_main_process: run = os.path.split(__file__)[-1].split(".")[0] @@ -138,8 +364,22 @@ def transforms(examples): with accelerator.accumulate(model): # Predict the noise residual - noise_pred = model(noisy_images, timesteps).sample - loss = F.mse_loss(noise_pred, noise) + model_output = model(noisy_images, timesteps).sample + + if args.prediction_type == "epsilon": + loss = F.mse_loss(model_output, noise) # this could have different weights! + elif args.prediction_type == "sample": + alpha_t = _extract_into_tensor( + noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1) + ) + snr_weights = alpha_t / (1 - alpha_t) + loss = snr_weights * F.mse_loss( + model_output, clean_images, reduction="none" + ) # use SNR weighting from distillation paper + loss = loss.mean() + else: + raise ValueError(f"Unsupported prediction type: {args.prediction_type}") + accelerator.backward(loss) if accelerator.sync_gradients: @@ -172,9 +412,17 @@ def transforms(examples): scheduler=noise_scheduler, ) - generator = torch.manual_seed(0) + deprecate("todo: remove this check", "0.10.0", "when the most used version is >= 0.8.0") + if diffusers_version < version.parse("0.8.0"): + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=pipeline.device).manual_seed(0) # run pipeline in inference (sample random noise and denoise) - images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images + images = pipeline( + generator=generator, + batch_size=args.eval_batch_size, + output_type="numpy", + ).images # denormalize the images and save to tensorboard images_processed = (images * 255).round().astype("uint8") @@ -193,55 +441,5 @@ def transforms(examples): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument("--local_rank", type=int, default=-1) - parser.add_argument("--dataset_name", type=str, default=None) - parser.add_argument("--dataset_config_name", type=str, default=None) - parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.") - parser.add_argument("--output_dir", type=str, default="ddpm-model-64") - parser.add_argument("--overwrite_output_dir", action="store_true") - parser.add_argument("--cache_dir", type=str, default=None) - parser.add_argument("--resolution", type=int, default=64) - parser.add_argument("--train_batch_size", type=int, default=16) - parser.add_argument("--eval_batch_size", type=int, default=16) - parser.add_argument("--num_epochs", type=int, default=100) - parser.add_argument("--save_images_epochs", type=int, default=10) - parser.add_argument("--save_model_epochs", type=int, default=10) - parser.add_argument("--gradient_accumulation_steps", type=int, default=1) - parser.add_argument("--learning_rate", type=float, default=1e-4) - parser.add_argument("--lr_scheduler", type=str, default="cosine") - parser.add_argument("--lr_warmup_steps", type=int, default=500) - parser.add_argument("--adam_beta1", type=float, default=0.95) - parser.add_argument("--adam_beta2", type=float, default=0.999) - parser.add_argument("--adam_weight_decay", type=float, default=1e-6) - parser.add_argument("--adam_epsilon", type=float, default=1e-08) - parser.add_argument("--use_ema", action="store_true", default=True) - parser.add_argument("--ema_inv_gamma", type=float, default=1.0) - parser.add_argument("--ema_power", type=float, default=3 / 4) - parser.add_argument("--ema_max_decay", type=float, default=0.9999) - parser.add_argument("--push_to_hub", action="store_true") - parser.add_argument("--hub_token", type=str, default=None) - parser.add_argument("--hub_model_id", type=str, default=None) - parser.add_argument("--hub_private_repo", action="store_true") - parser.add_argument("--logging_dir", type=str, default="logs") - parser.add_argument( - "--mixed_precision", - type=str, - default="no", - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose" - "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." - "and an Nvidia Ampere GPU." - ), - ) - - args = parser.parse_args() - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) - if env_local_rank != -1 and env_local_rank != args.local_rank: - args.local_rank = env_local_rank - - if args.dataset_name is None and args.train_data_dir is None: - raise ValueError("You must specify either a dataset name from the hub or a train data directory.") - + args = parse_args() main(args) diff --git a/examples/unconditional_image_generation/train_unconditional_ort.py b/examples/unconditional_image_generation/train_unconditional_ort.py new file mode 100644 index 000000000000..8259c835fc3a --- /dev/null +++ b/examples/unconditional_image_generation/train_unconditional_ort.py @@ -0,0 +1,251 @@ +import argparse +import math +import os + +import torch +import torch.nn.functional as F + +from accelerate import Accelerator +from accelerate.logging import get_logger +from datasets import load_dataset +from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel +from diffusers.hub_utils import init_git_repo, push_to_hub +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from onnxruntime.training.ortmodule import ORTModule +from torchvision.transforms import ( + CenterCrop, + Compose, + InterpolationMode, + Normalize, + RandomHorizontalFlip, + Resize, + ToTensor, +) +from tqdm.auto import tqdm + + +logger = get_logger(__name__) + + +def main(args): + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with="tensorboard", + logging_dir=logging_dir, + ) + + model = UNet2DModel( + sample_size=args.resolution, + in_channels=3, + out_channels=3, + layers_per_block=2, + block_out_channels=(128, 128, 256, 256, 512, 512), + down_block_types=( + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "AttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), + ) + model = ORTModule(model) + noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt") + optimizer = torch.optim.AdamW( + model.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + augmentations = Compose( + [ + Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), + CenterCrop(args.resolution), + RandomHorizontalFlip(), + ToTensor(), + Normalize([0.5], [0.5]), + ] + ) + + if args.dataset_name is not None: + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + use_auth_token=True if args.use_auth_token else None, + split="train", + ) + else: + dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train") + + def transforms(examples): + images = [augmentations(image.convert("RGB")) for image in examples["image"]] + return {"input": images} + + dataset.set_transform(transforms) + train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, + ) + + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + + ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay) + + if args.push_to_hub: + repo = init_git_repo(args, at_init=True) + + if accelerator.is_main_process: + run = os.path.split(__file__)[-1].split(".")[0] + accelerator.init_trackers(run) + + global_step = 0 + for epoch in range(args.num_epochs): + model.train() + progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process) + progress_bar.set_description(f"Epoch {epoch}") + for step, batch in enumerate(train_dataloader): + clean_images = batch["input"] + # Sample noise that we'll add to the images + noise = torch.randn(clean_images.shape).to(clean_images.device) + bsz = clean_images.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device + ).long() + + # Add noise to the clean images according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps) + + with accelerator.accumulate(model): + # Predict the noise residual + noise_pred = model(noisy_images, timesteps, return_dict=True)[0] + loss = F.mse_loss(noise_pred, noise) + accelerator.backward(loss) + + accelerator.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + lr_scheduler.step() + if args.use_ema: + ema_model.step(model) + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} + if args.use_ema: + logs["ema_decay"] = ema_model.decay + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + progress_bar.close() + + accelerator.wait_for_everyone() + + # Generate sample images for visual inspection + if accelerator.is_main_process: + if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: + pipeline = DDPMPipeline( + unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model), + scheduler=noise_scheduler, + ) + + generator = torch.manual_seed(0) + # run pipeline in inference (sample random noise and denoise) + images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images + + # denormalize the images and save to tensorboard + images_processed = (images * 255).round().astype("uint8") + accelerator.trackers[0].writer.add_images( + "test_samples", images_processed.transpose(0, 3, 1, 2), epoch + ) + + if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: + # save the model + if args.push_to_hub: + push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) + else: + pipeline.save_pretrained(args.output_dir) + accelerator.wait_for_everyone() + + accelerator.end_training() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument("--local_rank", type=int, default=-1) + parser.add_argument("--dataset_name", type=str, default=None) + parser.add_argument("--dataset_config_name", type=str, default=None) + parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.") + parser.add_argument("--output_dir", type=str, default="ddpm-model-64") + parser.add_argument("--overwrite_output_dir", action="store_true") + parser.add_argument("--cache_dir", type=str, default=None) + parser.add_argument("--resolution", type=int, default=64) + parser.add_argument("--train_batch_size", type=int, default=16) + parser.add_argument("--eval_batch_size", type=int, default=16) + parser.add_argument("--num_epochs", type=int, default=100) + parser.add_argument("--save_images_epochs", type=int, default=10) + parser.add_argument("--save_model_epochs", type=int, default=10) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=1e-4) + parser.add_argument("--lr_scheduler", type=str, default="cosine") + parser.add_argument("--lr_warmup_steps", type=int, default=500) + parser.add_argument("--adam_beta1", type=float, default=0.95) + parser.add_argument("--adam_beta2", type=float, default=0.999) + parser.add_argument("--adam_weight_decay", type=float, default=1e-6) + parser.add_argument("--adam_epsilon", type=float, default=1e-08) + parser.add_argument("--use_ema", action="store_true", default=True) + parser.add_argument("--ema_inv_gamma", type=float, default=1.0) + parser.add_argument("--ema_power", type=float, default=3 / 4) + parser.add_argument("--ema_max_decay", type=float, default=0.9999) + parser.add_argument("--push_to_hub", action="store_true") + parser.add_argument("--use_auth_token", action="store_true") + parser.add_argument("--hub_token", type=str, default=None) + parser.add_argument("--hub_model_id", type=str, default=None) + parser.add_argument("--hub_private_repo", action="store_true") + parser.add_argument("--logging_dir", type=str, default="logs") + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("You must specify either a dataset name from the hub or a train data directory.") + + main(args) diff --git a/scripts/convert_dance_diffusion_to_diffusers.py b/scripts/convert_dance_diffusion_to_diffusers.py new file mode 100755 index 000000000000..db183bce9262 --- /dev/null +++ b/scripts/convert_dance_diffusion_to_diffusers.py @@ -0,0 +1,339 @@ +#!/usr/bin/env python3 +import argparse +import math +import os +from copy import deepcopy + +import torch +from torch import nn + +from audio_diffusion.models import DiffusionAttnUnet1D +from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel +from diffusion import sampling + + +MODELS_MAP = { + "gwf-440k": { + "url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt", + "sample_rate": 48000, + "sample_size": 65536, + }, + "jmann-small-190k": { + "url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt", + "sample_rate": 48000, + "sample_size": 65536, + }, + "jmann-large-580k": { + "url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt", + "sample_rate": 48000, + "sample_size": 131072, + }, + "maestro-uncond-150k": { + "url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt", + "sample_rate": 16000, + "sample_size": 65536, + }, + "unlocked-uncond-250k": { + "url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt", + "sample_rate": 16000, + "sample_size": 65536, + }, + "honk-140k": { + "url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt", + "sample_rate": 16000, + "sample_size": 65536, + }, +} + + +def alpha_sigma_to_t(alpha, sigma): + """Returns a timestep, given the scaling factors for the clean image and for + the noise.""" + return torch.atan2(sigma, alpha) / math.pi * 2 + + +def get_crash_schedule(t): + sigma = torch.sin(t * math.pi / 2) ** 2 + alpha = (1 - sigma**2) ** 0.5 + return alpha_sigma_to_t(alpha, sigma) + + +class Object(object): + pass + + +class DiffusionUncond(nn.Module): + def __init__(self, global_args): + super().__init__() + + self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4) + self.diffusion_ema = deepcopy(self.diffusion) + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + +def download(model_name): + url = MODELS_MAP[model_name]["url"] + os.system(f"wget {url} ./") + + return f"./{model_name}.ckpt" + + +DOWN_NUM_TO_LAYER = { + "1": "resnets.0", + "2": "attentions.0", + "3": "resnets.1", + "4": "attentions.1", + "5": "resnets.2", + "6": "attentions.2", +} +UP_NUM_TO_LAYER = { + "8": "resnets.0", + "9": "attentions.0", + "10": "resnets.1", + "11": "attentions.1", + "12": "resnets.2", + "13": "attentions.2", +} +MID_NUM_TO_LAYER = { + "1": "resnets.0", + "2": "attentions.0", + "3": "resnets.1", + "4": "attentions.1", + "5": "resnets.2", + "6": "attentions.2", + "8": "resnets.3", + "9": "attentions.3", + "10": "resnets.4", + "11": "attentions.4", + "12": "resnets.5", + "13": "attentions.5", +} +DEPTH_0_TO_LAYER = { + "0": "resnets.0", + "1": "resnets.1", + "2": "resnets.2", + "4": "resnets.0", + "5": "resnets.1", + "6": "resnets.2", +} + +RES_CONV_MAP = { + "skip": "conv_skip", + "main.0": "conv_1", + "main.1": "group_norm_1", + "main.3": "conv_2", + "main.4": "group_norm_2", +} + +ATTN_MAP = { + "norm": "group_norm", + "qkv_proj": ["query", "key", "value"], + "out_proj": ["proj_attn"], +} + + +def convert_resconv_naming(name): + if name.startswith("skip"): + return name.replace("skip", RES_CONV_MAP["skip"]) + + # name has to be of format main.{digit} + if not name.startswith("main."): + raise ValueError(f"ResConvBlock error with {name}") + + return name.replace(name[:6], RES_CONV_MAP[name[:6]]) + + +def convert_attn_naming(name): + for key, value in ATTN_MAP.items(): + if name.startswith(key) and not isinstance(value, list): + return name.replace(key, value) + elif name.startswith(key): + return [name.replace(key, v) for v in value] + raise ValueError(f"Attn error with {name}") + + +def rename(input_string, max_depth=13): + string = input_string + + if string.split(".")[0] == "timestep_embed": + return string.replace("timestep_embed", "time_proj") + + depth = 0 + if string.startswith("net.3."): + depth += 1 + string = string[6:] + elif string.startswith("net."): + string = string[4:] + + while string.startswith("main.7."): + depth += 1 + string = string[7:] + + if string.startswith("main."): + string = string[5:] + + # mid block + if string[:2].isdigit(): + layer_num = string[:2] + string_left = string[2:] + else: + layer_num = string[0] + string_left = string[1:] + + if depth == max_depth: + new_layer = MID_NUM_TO_LAYER[layer_num] + prefix = "mid_block" + elif depth > 0 and int(layer_num) < 7: + new_layer = DOWN_NUM_TO_LAYER[layer_num] + prefix = f"down_blocks.{depth}" + elif depth > 0 and int(layer_num) > 7: + new_layer = UP_NUM_TO_LAYER[layer_num] + prefix = f"up_blocks.{max_depth - depth - 1}" + elif depth == 0: + new_layer = DEPTH_0_TO_LAYER[layer_num] + prefix = f"up_blocks.{max_depth - 1}" if int(layer_num) > 3 else "down_blocks.0" + + if not string_left.startswith("."): + raise ValueError(f"Naming error with {input_string} and string_left: {string_left}.") + + string_left = string_left[1:] + + if "resnets" in new_layer: + string_left = convert_resconv_naming(string_left) + elif "attentions" in new_layer: + new_string_left = convert_attn_naming(string_left) + string_left = new_string_left + + if not isinstance(string_left, list): + new_string = prefix + "." + new_layer + "." + string_left + else: + new_string = [prefix + "." + new_layer + "." + s for s in string_left] + return new_string + + +def rename_orig_weights(state_dict): + new_state_dict = {} + for k, v in state_dict.items(): + if k.endswith("kernel"): + # up- and downsample layers, don't have trainable weights + continue + + new_k = rename(k) + + # check if we need to transform from Conv => Linear for attention + if isinstance(new_k, list): + new_state_dict = transform_conv_attns(new_state_dict, new_k, v) + else: + new_state_dict[new_k] = v + + return new_state_dict + + +def transform_conv_attns(new_state_dict, new_k, v): + if len(new_k) == 1: + if len(v.shape) == 3: + # weight + new_state_dict[new_k[0]] = v[:, :, 0] + else: + # bias + new_state_dict[new_k[0]] = v + else: + # qkv matrices + trippled_shape = v.shape[0] + single_shape = trippled_shape // 3 + for i in range(3): + if len(v.shape) == 3: + new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape, :, 0] + else: + new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape] + return new_state_dict + + +def main(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + model_name = args.model_path.split("/")[-1].split(".")[0] + if not os.path.isfile(args.model_path): + assert ( + model_name == args.model_path + ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}" + args.model_path = download(model_name) + + sample_rate = MODELS_MAP[model_name]["sample_rate"] + sample_size = MODELS_MAP[model_name]["sample_size"] + + config = Object() + config.sample_size = sample_size + config.sample_rate = sample_rate + config.latent_dim = 0 + + diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate) + diffusers_state_dict = diffusers_model.state_dict() + + orig_model = DiffusionUncond(config) + orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"]) + orig_model = orig_model.diffusion_ema.eval() + orig_model_state_dict = orig_model.state_dict() + renamed_state_dict = rename_orig_weights(orig_model_state_dict) + + renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys()) + diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys()) + + assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}" + assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}" + + for key, value in renamed_state_dict.items(): + assert ( + diffusers_state_dict[key].squeeze().shape == value.squeeze().shape + ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}" + if key == "time_proj.weight": + value = value.squeeze() + + diffusers_state_dict[key] = value + + diffusers_model.load_state_dict(diffusers_state_dict) + + steps = 100 + seed = 33 + + diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps) + + generator = torch.manual_seed(seed) + noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device) + + t = torch.linspace(1, 0, steps + 1, device=device)[:-1] + step_list = get_crash_schedule(t) + + pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler) + + generator = torch.manual_seed(33) + audio = pipe(num_inference_steps=steps, generator=generator).audios + + generated = sampling.iplms_sample(orig_model, noise, step_list, {}) + generated = generated.clamp(-1, 1) + + diff_sum = (generated - audio).abs().sum() + diff_max = (generated - audio).abs().max() + + if args.save: + pipe.save_pretrained(args.checkpoint_path) + + print("Diff sum", diff_sum) + print("Diff max", diff_max) + + assert diff_max < 1e-3, f"Diff max: {diff_max} is too much :-/" + + print(f"Conversion for {model_name} successful!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") + parser.add_argument( + "--save", default=True, type=bool, required=False, help="Whether to save the converted model or not." + ) + parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") + args = parser.parse_args() + + main(args) diff --git a/scripts/convert_ldm_original_checkpoint_to_diffusers.py b/scripts/convert_ldm_original_checkpoint_to_diffusers.py index 52865792256f..f547e96f4ed5 100644 --- a/scripts/convert_ldm_original_checkpoint_to_diffusers.py +++ b/scripts/convert_ldm_original_checkpoint_to_diffusers.py @@ -112,9 +112,9 @@ def assign_to_checkpoint( continue # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid.resnets.0") - new_path = new_path.replace("middle_block.1", "mid.attentions.0") - new_path = new_path.replace("middle_block.2", "mid.resnets.1") + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") if additional_replacements is not None: for replacement in additional_replacements: @@ -175,15 +175,16 @@ def convert_ldm_checkpoint(checkpoint, config): attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] if f"input_blocks.{i}.0.op.weight" in checkpoint: - new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[ + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[ f"input_blocks.{i}.0.op.weight" ] - new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[ + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[ f"input_blocks.{i}.0.op.bias" ] + continue paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"downsample_blocks.{block_id}.resnets.{layer_in_block_id}"} + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"} assign_to_checkpoint( paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config @@ -193,18 +194,18 @@ def convert_ldm_checkpoint(checkpoint, config): paths = renew_attention_paths(attentions) meta_path = { "old": f"input_blocks.{i}.1", - "new": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}", + "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}", } to_split = { f"input_blocks.{i}.1.qkv.bias": { - "key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias", - "query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias", - "value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias", + "key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias", + "query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias", + "value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias", }, f"input_blocks.{i}.1.qkv.weight": { - "key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight", - "query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight", - "value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight", + "key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight", + "query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight", + "value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight", }, } assign_to_checkpoint( diff --git a/scripts/convert_models_diffuser_to_diffusers.py b/scripts/convert_models_diffuser_to_diffusers.py new file mode 100644 index 000000000000..9475f7da93fb --- /dev/null +++ b/scripts/convert_models_diffuser_to_diffusers.py @@ -0,0 +1,100 @@ +import json +import os + +import torch + +from diffusers import UNet1DModel + + +os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True) +os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True) + +os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True) + + +def unet(hor): + if hor == 128: + down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D") + block_out_channels = (32, 128, 256) + up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D") + + elif hor == 32: + down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D") + block_out_channels = (32, 64, 128, 256) + up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D") + model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch") + state_dict = model.state_dict() + config = dict( + down_block_types=down_block_types, + block_out_channels=block_out_channels, + up_block_types=up_block_types, + layers_per_block=1, + use_timestep_embedding=True, + out_block_type="OutConv1DBlock", + norm_num_groups=8, + downsample_each_block=False, + in_channels=14, + out_channels=14, + extra_in_channels=0, + time_embedding_type="positional", + flip_sin_to_cos=False, + freq_shift=1, + sample_size=65536, + mid_block_type="MidResTemporalBlock1D", + act_fn="mish", + ) + hf_value_function = UNet1DModel(**config) + print(f"length of state dict: {len(state_dict.keys())}") + print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") + mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys())) + for k, v in mapping.items(): + state_dict[v] = state_dict.pop(k) + hf_value_function.load_state_dict(state_dict) + + torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin") + with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f: + json.dump(config, f) + + +def value_function(): + config = dict( + in_channels=14, + down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + up_block_types=(), + out_block_type="ValueFunction", + mid_block_type="ValueFunctionMidBlock1D", + block_out_channels=(32, 64, 128, 256), + layers_per_block=1, + downsample_each_block=True, + sample_size=65536, + out_channels=14, + extra_in_channels=0, + time_embedding_type="positional", + use_timestep_embedding=True, + flip_sin_to_cos=False, + freq_shift=1, + norm_num_groups=8, + act_fn="mish", + ) + + model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch") + state_dict = model + hf_value_function = UNet1DModel(**config) + print(f"length of state dict: {len(state_dict.keys())}") + print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") + + mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys())) + for k, v in mapping.items(): + state_dict[v] = state_dict.pop(k) + + hf_value_function.load_state_dict(state_dict) + + torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin") + with open("hub/hopper-medium-v2/value_function/config.json", "w") as f: + json.dump(config, f) + + +if __name__ == "__main__": + unet(32) + # unet(128) + value_function() diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index db1b30736984..2d354df93818 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -30,6 +30,9 @@ from diffusers import ( AutoencoderKL, DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, LDMTextToImagePipeline, LMSDiscreteScheduler, PNDMScheduler, @@ -208,6 +211,7 @@ def create_unet_diffusers_config(original_config): """ Creates a config for the diffusers based on the config of the LDM model. """ + model_params = original_config.model.params unet_params = original_config.model.params.unet_config.params block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] @@ -227,7 +231,7 @@ def create_unet_diffusers_config(original_config): resolution //= 2 config = dict( - sample_size=unet_params.image_size, + sample_size=model_params.image_size, in_channels=unet_params.in_channels, out_channels=unet_params.out_channels, down_block_types=tuple(down_block_types), @@ -285,15 +289,34 @@ def create_ldm_bert_config(original_config): return config -def convert_ldm_unet_checkpoint(checkpoint, config): +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): """ Takes a state dict and a config, and returns a converted checkpoint. """ # extract state_dict for UNet unet_state_dict = {} - unet_key = "model.diffusion_model." keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + if extract_ema: + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + for key in keys: if key.startswith(unet_key): unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) @@ -628,7 +651,16 @@ def convert_ldm_clip_checkpoint(checkpoint): "--scheduler_type", default="pndm", type=str, - help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']", + help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']", + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), ) parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") @@ -641,7 +673,9 @@ def convert_ldm_clip_checkpoint(checkpoint): args.original_config_file = "./v1-inference.yaml" original_config = OmegaConf.load(args.original_config_file) - checkpoint = torch.load(args.checkpoint_path)["state_dict"] + + checkpoint = torch.load(args.checkpoint_path) + checkpoint = checkpoint["state_dict"] num_train_timesteps = original_config.model.params.timesteps beta_start = original_config.model.params.linear_start @@ -656,6 +690,16 @@ def convert_ldm_clip_checkpoint(checkpoint): ) elif args.scheduler_type == "lms": scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") + elif args.scheduler_type == "euler": + scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") + elif args.scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear" + ) + elif args.scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear" + ) elif args.scheduler_type == "ddim": scheduler = DDIMScheduler( beta_start=beta_start, @@ -669,7 +713,9 @@ def convert_ldm_clip_checkpoint(checkpoint): # Convert the UNet2DConditionModel model. unet_config = create_unet_diffusers_config(original_config) - converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config) + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema + ) unet = UNet2DConditionModel(**unet_config) unet.load_state_dict(converted_unet_checkpoint) diff --git a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py index a388c0d078dc..26d3d5618f88 100644 --- a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py +++ b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py @@ -21,7 +21,7 @@ from torch.onnx import export import onnx -from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline +from diffusers import OnnxStableDiffusionPipeline, StableDiffusionPipeline from diffusers.onnx_utils import OnnxRuntimeModel from packaging import version @@ -69,11 +69,20 @@ def onnx_export( @torch.no_grad() -def convert_models(model_path: str, output_path: str, opset: int): - pipeline = StableDiffusionPipeline.from_pretrained(model_path) +def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False): + dtype = torch.float16 if fp16 else torch.float32 + if fp16 and torch.cuda.is_available(): + device = "cuda" + elif fp16 and not torch.cuda.is_available(): + raise ValueError("`float16` model export is only supported on GPUs with CUDA") + else: + device = "cpu" + pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device) output_path = Path(output_path) # TEXT ENCODER + num_tokens = pipeline.text_encoder.config.max_position_embeddings + text_hidden_size = pipeline.text_encoder.config.hidden_size text_input = pipeline.tokenizer( "A sample prompt", padding="max_length", @@ -84,7 +93,7 @@ def convert_models(model_path: str, output_path: str, opset: int): onnx_export( pipeline.text_encoder, # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files - model_args=(text_input.input_ids.to(torch.int32)), + model_args=(text_input.input_ids.to(device=device, dtype=torch.int32)), output_path=output_path / "text_encoder" / "model.onnx", ordered_input_names=["input_ids"], output_names=["last_hidden_state", "pooler_output"], @@ -93,12 +102,20 @@ def convert_models(model_path: str, output_path: str, opset: int): }, opset=opset, ) + del pipeline.text_encoder # UNET + unet_in_channels = pipeline.unet.config.in_channels + unet_sample_size = pipeline.unet.config.sample_size unet_path = output_path / "unet" / "model.onnx" onnx_export( pipeline.unet, - model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False), + model_args=( + torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), + torch.randn(2).to(device=device, dtype=dtype), + torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype), + False, + ), output_path=unet_path, ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], output_names=["out_sample"], # has to be different from "sample" for correct tracing @@ -125,14 +142,20 @@ def convert_models(model_path: str, output_path: str, opset: int): location="weights.pb", convert_attribute=False, ) + del pipeline.unet # VAE ENCODER vae_encoder = pipeline.vae + vae_in_channels = vae_encoder.config.in_channels + vae_sample_size = vae_encoder.config.sample_size # need to get the raw tensor output (sample) from the encoder vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample() onnx_export( vae_encoder, - model_args=(torch.randn(1, 3, 512, 512), False), + model_args=( + torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype), + False, + ), output_path=output_path / "vae_encoder" / "model.onnx", ordered_input_names=["sample", "return_dict"], output_names=["latent_sample"], @@ -144,11 +167,16 @@ def convert_models(model_path: str, output_path: str, opset: int): # VAE DECODER vae_decoder = pipeline.vae + vae_latent_channels = vae_decoder.config.latent_channels + vae_out_channels = vae_decoder.config.out_channels # forward only through the decoder part vae_decoder.forward = vae_encoder.decode onnx_export( vae_decoder, - model_args=(torch.randn(1, 4, 64, 64), False), + model_args=( + torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), + False, + ), output_path=output_path / "vae_decoder" / "model.onnx", ordered_input_names=["latent_sample", "return_dict"], output_names=["sample"], @@ -157,37 +185,59 @@ def convert_models(model_path: str, output_path: str, opset: int): }, opset=opset, ) + del pipeline.vae # SAFETY CHECKER - safety_checker = pipeline.safety_checker - safety_checker.forward = safety_checker.forward_onnx - onnx_export( - pipeline.safety_checker, - model_args=(torch.randn(1, 3, 224, 224), torch.randn(1, 512, 512, 3)), - output_path=output_path / "safety_checker" / "model.onnx", - ordered_input_names=["clip_input", "images"], - output_names=["out_images", "has_nsfw_concepts"], - dynamic_axes={ - "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - "images": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - }, - opset=opset, - ) + if pipeline.safety_checker is not None: + safety_checker = pipeline.safety_checker + clip_num_channels = safety_checker.config.vision_config.num_channels + clip_image_size = safety_checker.config.vision_config.image_size + safety_checker.forward = safety_checker.forward_onnx + onnx_export( + pipeline.safety_checker, + model_args=( + torch.randn( + 1, + clip_num_channels, + clip_image_size, + clip_image_size, + ).to(device=device, dtype=dtype), + torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(device=device, dtype=dtype), + ), + output_path=output_path / "safety_checker" / "model.onnx", + ordered_input_names=["clip_input", "images"], + output_names=["out_images", "has_nsfw_concepts"], + dynamic_axes={ + "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + "images": {0: "batch", 1: "height", 2: "width", 3: "channels"}, + }, + opset=opset, + ) + del pipeline.safety_checker + safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker") + feature_extractor = pipeline.feature_extractor + else: + safety_checker = None + feature_extractor = None - onnx_pipeline = StableDiffusionOnnxPipeline( + onnx_pipeline = OnnxStableDiffusionPipeline( + vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"), vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"), text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"), tokenizer=pipeline.tokenizer, unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"), scheduler=pipeline.scheduler, - safety_checker=OnnxRuntimeModel.from_pretrained(output_path / "safety_checker"), - feature_extractor=pipeline.feature_extractor, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + requires_safety_checker=safety_checker is not None, ) onnx_pipeline.save_pretrained(output_path) print("ONNX pipeline saved to", output_path) - _ = StableDiffusionOnnxPipeline.from_pretrained(output_path, provider="CPUExecutionProvider") + del pipeline + del onnx_pipeline + _ = OnnxStableDiffusionPipeline.from_pretrained(output_path, provider="CPUExecutionProvider") print("ONNX pipeline is loadable") @@ -209,7 +259,8 @@ def convert_models(model_path: str, output_path: str, opset: int): type=int, help="The version of the ONNX operator set to use.", ) + parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode") args = parser.parse_args() - convert_models(args.model_path, args.output_path, args.opset) + convert_models(args.model_path, args.output_path, args.opset, args.fp16) diff --git a/scripts/convert_versatile_diffusion_to_diffusers.py b/scripts/convert_versatile_diffusion_to_diffusers.py new file mode 100644 index 000000000000..86fb0e7b4c97 --- /dev/null +++ b/scripts/convert_versatile_diffusion_to_diffusers.py @@ -0,0 +1,791 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conversion script for the Versatile Stable Diffusion checkpoints. """ + +import argparse +from argparse import Namespace + +import torch + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, + VersatileDiffusionPipeline, +) +from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel +from transformers import ( + CLIPFeatureExtractor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + + +SCHEDULER_CONFIG = Namespace( + **{ + "beta_linear_start": 0.00085, + "beta_linear_end": 0.012, + "timesteps": 1000, + "scale_factor": 0.18215, + } +) + +IMAGE_UNET_CONFIG = Namespace( + **{ + "input_channels": 4, + "model_channels": 320, + "output_channels": 4, + "num_noattn_blocks": [2, 2, 2, 2], + "channel_mult": [1, 2, 4, 4], + "with_attn": [True, True, True, False], + "num_heads": 8, + "context_dim": 768, + "use_checkpoint": True, + } +) + +TEXT_UNET_CONFIG = Namespace( + **{ + "input_channels": 768, + "model_channels": 320, + "output_channels": 768, + "num_noattn_blocks": [2, 2, 2, 2], + "channel_mult": [1, 2, 4, 4], + "second_dim": [4, 4, 4, 4], + "with_attn": [True, True, True, False], + "num_heads": 8, + "context_dim": 768, + "use_checkpoint": True, + } +) + +AUTOENCODER_CONFIG = Namespace( + **{ + "double_z": True, + "z_channels": 4, + "resolution": 256, + "in_channels": 3, + "out_ch": 3, + "ch": 128, + "ch_mult": [1, 2, 4, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + } +) + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif path["old"] in old_checkpoint: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_image_unet_diffusers_config(unet_params): + """ + Creates a config for the diffusers based on the config of the VD model. + """ + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if unet_params.with_attn[i] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if unet_params.with_attn[-i - 1] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks): + raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.") + + config = dict( + sample_size=None, + in_channels=unet_params.input_channels, + out_channels=unet_params.output_channels, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=unet_params.num_noattn_blocks[0], + cross_attention_dim=unet_params.context_dim, + attention_head_dim=unet_params.num_heads, + ) + + return config + + +def create_text_unet_diffusers_config(unet_params): + """ + Creates a config for the diffusers based on the config of the VD model. + """ + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlockFlat" if unet_params.with_attn[i] else "DownBlockFlat" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlockFlat" if unet_params.with_attn[-i - 1] else "UpBlockFlat" + up_block_types.append(block_type) + resolution //= 2 + + if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks): + raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.") + + config = dict( + sample_size=None, + in_channels=(unet_params.input_channels, 1, 1), + out_channels=(unet_params.output_channels, 1, 1), + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=unet_params.num_noattn_blocks[0], + cross_attention_dim=unet_params.context_dim, + attention_head_dim=unet_params.num_heads, + ) + + return config + + +def create_vae_diffusers_config(vae_params): + """ + Creates a config for the diffusers based on the config of the VD model. + """ + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = dict( + sample_size=vae_params.resolution, + in_channels=vae_params.in_channels, + out_channels=vae_params.out_ch, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=vae_params.z_channels, + layers_per_block=vae_params.num_res_blocks, + ) + return config + + +def create_diffusers_scheduler(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config.model.params.timesteps, + beta_start=original_config.model.params.linear_start, + beta_end=original_config.model.params.linear_end, + beta_schedule="scaled_linear", + ) + return schedular + + +def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100: + print("Checkpoint has both EMA and non-EMA weights.") + if extract_ema: + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["model.diffusion_model.time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["model.diffusion_model.time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["model.diffusion_model.time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["model.diffusion_model.time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + elif f"input_blocks.{i}.0.weight" in unet_state_dict: + # text_unet uses linear layers in place of downsamplers + shape = unet_state_dict[f"input_blocks.{i}.0.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if ["conv.weight", "conv.bias"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + elif f"output_blocks.{i}.1.weight" in unet_state_dict: + # text_unet uses linear layers in place of upsamplers + shape = unet_state_dict[f"output_blocks.{i}.1.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop( + f"output_blocks.{i}.1.weight" + ) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop( + f"output_blocks.{i}.1.bias" + ) + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + elif f"output_blocks.{i}.2.weight" in unet_state_dict: + # text_unet uses linear layers in place of upsamplers + shape = unet_state_dict[f"output_blocks.{i}.2.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop( + f"output_blocks.{i}.2.weight" + ) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop( + f"output_blocks.{i}.2.bias" + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +def convert_vd_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + for key in keys: + vae_state_dict[key] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--unet_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--vae_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--optimus_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--scheduler_type", + default="pndm", + type=str, + help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']", + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + args = parser.parse_args() + + scheduler_config = SCHEDULER_CONFIG + + num_train_timesteps = scheduler_config.timesteps + beta_start = scheduler_config.beta_linear_start + beta_end = scheduler_config.beta_linear_end + if args.scheduler_type == "pndm": + scheduler = PNDMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + skip_prk_steps=True, + steps_offset=1, + ) + elif args.scheduler_type == "lms": + scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") + elif args.scheduler_type == "euler": + scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") + elif args.scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear" + ) + elif args.scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear" + ) + elif args.scheduler_type == "ddim": + scheduler = DDIMScheduler( + beta_start=beta_start, + beta_end=beta_end, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + else: + raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!") + + # Convert the UNet2DConditionModel models. + if args.unet_checkpoint_path is not None: + # image UNet + image_unet_config = create_image_unet_diffusers_config(IMAGE_UNET_CONFIG) + checkpoint = torch.load(args.unet_checkpoint_path) + converted_image_unet_checkpoint = convert_vd_unet_checkpoint( + checkpoint, image_unet_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema + ) + image_unet = UNet2DConditionModel(**image_unet_config) + image_unet.load_state_dict(converted_image_unet_checkpoint) + + # text UNet + text_unet_config = create_text_unet_diffusers_config(TEXT_UNET_CONFIG) + converted_text_unet_checkpoint = convert_vd_unet_checkpoint( + checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema + ) + text_unet = UNetFlatConditionModel(**text_unet_config) + text_unet.load_state_dict(converted_text_unet_checkpoint) + + # Convert the VAE model. + if args.vae_checkpoint_path is not None: + vae_config = create_vae_diffusers_config(AUTOENCODER_CONFIG) + checkpoint = torch.load(args.vae_checkpoint_path) + converted_vae_checkpoint = convert_vd_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + image_feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14") + text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + + pipe = VersatileDiffusionPipeline( + scheduler=scheduler, + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + ) + pipe.save_pretrained(args.dump_path) diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py new file mode 100644 index 000000000000..85db67844acd --- /dev/null +++ b/scripts/convert_vq_diffusion_to_diffusers.py @@ -0,0 +1,925 @@ +""" +This script ports models from VQ-diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers. + +It currently only supports porting the ITHQ dataset. + +ITHQ dataset: +```sh +# From the root directory of diffusers. + +# Download the VQVAE checkpoint +$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D -O ithq_vqvae.pth + +# Download the VQVAE config +# NOTE that in VQ-diffusion the documented file is `configs/ithq.yaml` but the target class +# `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE` +# loads `OUTPUT/pretrained_model/taming_dvae/config.yaml` +$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml + +# Download the main model checkpoint +$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_learnable.pth?sv=2020-10-02&st=2022-05-30T10%3A22%3A06Z&se=2030-05-31T10%3A22%3A00Z&sr=b&sp=r&sig=GOE%2Bza02%2FPnGxYVOOPtwrTR4RA3%2F5NVgMxdW4kjaEZ8%3D -O ithq_learnable.pth + +# Download the main model config +$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/configs/ithq.yaml -O ithq.yaml + +# run the convert script +$ python ./scripts/convert_vq_diffusion_to_diffusers.py \ + --checkpoint_path ./ithq_learnable.pth \ + --original_config_file ./ithq.yaml \ + --vqvae_checkpoint_path ./ithq_vqvae.pth \ + --vqvae_original_config_file ./ithq_vqvae.yaml \ + --dump_path +``` +""" + +import argparse +import tempfile + +import torch + +import yaml +from accelerate import init_empty_weights, load_checkpoint_and_dispatch +from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel +from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings +from transformers import CLIPTextModel, CLIPTokenizer +from yaml.loader import FullLoader + + +try: + from omegaconf import OmegaConf +except ImportError: + raise ImportError( + "OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install" + " OmegaConf`." + ) + +# vqvae model + +PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"] + + +def vqvae_model_from_original_config(original_config): + assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers." + + original_config = original_config.params + + original_encoder_config = original_config.encoder_config.params + original_decoder_config = original_config.decoder_config.params + + in_channels = original_encoder_config.in_channels + out_channels = original_decoder_config.out_ch + + down_block_types = get_down_block_types(original_encoder_config) + up_block_types = get_up_block_types(original_decoder_config) + + assert original_encoder_config.ch == original_decoder_config.ch + assert original_encoder_config.ch_mult == original_decoder_config.ch_mult + block_out_channels = tuple( + [original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult] + ) + + assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks + layers_per_block = original_encoder_config.num_res_blocks + + assert original_encoder_config.z_channels == original_decoder_config.z_channels + latent_channels = original_encoder_config.z_channels + + num_vq_embeddings = original_config.n_embed + + # Hard coded value for ResnetBlock.GoupNorm(num_groups) in VQ-diffusion + norm_num_groups = 32 + + e_dim = original_config.embed_dim + + model = VQModel( + in_channels=in_channels, + out_channels=out_channels, + down_block_types=down_block_types, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + latent_channels=latent_channels, + num_vq_embeddings=num_vq_embeddings, + norm_num_groups=norm_num_groups, + vq_embed_dim=e_dim, + ) + + return model + + +def get_down_block_types(original_encoder_config): + attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions) + num_resolutions = len(original_encoder_config.ch_mult) + resolution = coerce_resolution(original_encoder_config.resolution) + + curr_res = resolution + down_block_types = [] + + for _ in range(num_resolutions): + if curr_res in attn_resolutions: + down_block_type = "AttnDownEncoderBlock2D" + else: + down_block_type = "DownEncoderBlock2D" + + down_block_types.append(down_block_type) + + curr_res = [r // 2 for r in curr_res] + + return down_block_types + + +def get_up_block_types(original_decoder_config): + attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions) + num_resolutions = len(original_decoder_config.ch_mult) + resolution = coerce_resolution(original_decoder_config.resolution) + + curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution] + up_block_types = [] + + for _ in reversed(range(num_resolutions)): + if curr_res in attn_resolutions: + up_block_type = "AttnUpDecoderBlock2D" + else: + up_block_type = "UpDecoderBlock2D" + + up_block_types.append(up_block_type) + + curr_res = [r * 2 for r in curr_res] + + return up_block_types + + +def coerce_attn_resolutions(attn_resolutions): + attn_resolutions = OmegaConf.to_object(attn_resolutions) + attn_resolutions_ = [] + for ar in attn_resolutions: + if isinstance(ar, (list, tuple)): + attn_resolutions_.append(list(ar)) + else: + attn_resolutions_.append([ar, ar]) + return attn_resolutions_ + + +def coerce_resolution(resolution): + resolution = OmegaConf.to_object(resolution) + if isinstance(resolution, int): + resolution = [resolution, resolution] # H, W + elif isinstance(resolution, (tuple, list)): + resolution = list(resolution) + else: + raise ValueError("Unknown type of resolution:", resolution) + return resolution + + +# done vqvae model + +# vqvae checkpoint + + +def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint)) + + # quant_conv + + diffusers_checkpoint.update( + { + "quant_conv.weight": checkpoint["quant_conv.weight"], + "quant_conv.bias": checkpoint["quant_conv.bias"], + } + ) + + # quantize + diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]}) + + # post_quant_conv + diffusers_checkpoint.update( + { + "post_quant_conv.weight": checkpoint["post_quant_conv.weight"], + "post_quant_conv.bias": checkpoint["post_quant_conv.bias"], + } + ) + + # decoder + diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint)) + + return diffusers_checkpoint + + +def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv_in + diffusers_checkpoint.update( + { + "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"], + "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"], + } + ) + + # down_blocks + for down_block_idx, down_block in enumerate(model.encoder.down_blocks): + diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}" + down_block_prefix = f"encoder.down.{down_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(down_block.resnets): + diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # downsample + + # do not include the downsample when on the last down block + # There is no downsample on the last down block + if down_block_idx != len(model.encoder.down_blocks) - 1: + # There's a single downsample in the original checkpoint but a list of downsamples + # in the diffusers model. + diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv" + downsample_prefix = f"{down_block_prefix}.downsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(down_block, "attentions"): + for attention_idx, _ in enumerate(down_block.attentions): + diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{down_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder + diffusers_attention_prefix = "encoder.mid_block.attentions.0" + attention_prefix = "encoder.mid.attn_1" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion encoder + resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"], + "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"], + # conv_out + "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"], + "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv in + diffusers_checkpoint.update( + { + "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"], + "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"], + } + ) + + # up_blocks + + for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks): + # up_blocks are stored in reverse order in the VQ-diffusion checkpoint + orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx + + diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}" + up_block_prefix = f"decoder.up.{orig_up_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(up_block.resnets): + diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # upsample + + # there is no up sample on the last up block + if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1: + # There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples + # in the diffusers model. + diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv" + downsample_prefix = f"{up_block_prefix}.upsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(up_block, "attentions"): + for attention_idx, _ in enumerate(up_block.attentions): + diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{up_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder + diffusers_attention_prefix = "decoder.mid_block.attentions.0" + attention_prefix = "decoder.mid.attn_1" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion decoder + resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"], + "decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"], + # conv_out + "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"], + "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + rv = { + # norm1 + f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"], + f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"], + # conv1 + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], + # norm2 + f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"], + f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"], + # conv2 + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], + } + + if resnet.conv_shortcut is not None: + rv.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], + } + ) + + return rv + + +def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): + return { + # group_norm + f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"], + f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], + # query + f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"], + # key + f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"], + # value + f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"], + # proj_attn + f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ + :, :, 0, 0 + ], + f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + + +# done vqvae checkpoint + +# transformer model + +PORTED_DIFFUSIONS = ["image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer"] +PORTED_TRANSFORMERS = ["image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer"] +PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding"] + + +def transformer_model_from_original_config( + original_diffusion_config, original_transformer_config, original_content_embedding_config +): + assert ( + original_diffusion_config.target in PORTED_DIFFUSIONS + ), f"{original_diffusion_config.target} has not yet been ported to diffusers." + assert ( + original_transformer_config.target in PORTED_TRANSFORMERS + ), f"{original_transformer_config.target} has not yet been ported to diffusers." + assert ( + original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS + ), f"{original_content_embedding_config.target} has not yet been ported to diffusers." + + original_diffusion_config = original_diffusion_config.params + original_transformer_config = original_transformer_config.params + original_content_embedding_config = original_content_embedding_config.params + + inner_dim = original_transformer_config["n_embd"] + + n_heads = original_transformer_config["n_head"] + + # VQ-Diffusion gives dimension of the multi-headed attention layers as the + # number of attention heads times the sequence length (the dimension) of a + # single head. We want to specify our attention blocks with those values + # specified separately + assert inner_dim % n_heads == 0 + d_head = inner_dim // n_heads + + depth = original_transformer_config["n_layer"] + context_dim = original_transformer_config["condition_dim"] + + num_embed = original_content_embedding_config["num_embed"] + # the number of embeddings in the transformer includes the mask embedding. + # the content embedding (the vqvae) does not include the mask embedding. + num_embed = num_embed + 1 + + height = original_transformer_config["content_spatial_size"][0] + width = original_transformer_config["content_spatial_size"][1] + + assert width == height, "width has to be equal to height" + dropout = original_transformer_config["resid_pdrop"] + num_embeds_ada_norm = original_diffusion_config["diffusion_step"] + + model_kwargs = { + "attention_bias": True, + "cross_attention_dim": context_dim, + "attention_head_dim": d_head, + "num_layers": depth, + "dropout": dropout, + "num_attention_heads": n_heads, + "num_vector_embeds": num_embed, + "num_embeds_ada_norm": num_embeds_ada_norm, + "norm_num_groups": 32, + "sample_size": width, + "activation_fn": "geglu-approximate", + } + + model = Transformer2DModel(**model_kwargs) + return model + + +# done transformer model + +# transformer checkpoint + + +def transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + transformer_prefix = "transformer.transformer" + + diffusers_latent_image_embedding_prefix = "latent_image_embedding" + latent_image_embedding_prefix = f"{transformer_prefix}.content_emb" + + # DalleMaskImageEmbedding + diffusers_checkpoint.update( + { + f"{diffusers_latent_image_embedding_prefix}.emb.weight": checkpoint[ + f"{latent_image_embedding_prefix}.emb.weight" + ], + f"{diffusers_latent_image_embedding_prefix}.height_emb.weight": checkpoint[ + f"{latent_image_embedding_prefix}.height_emb.weight" + ], + f"{diffusers_latent_image_embedding_prefix}.width_emb.weight": checkpoint[ + f"{latent_image_embedding_prefix}.width_emb.weight" + ], + } + ) + + # transformer blocks + for transformer_block_idx, transformer_block in enumerate(model.transformer_blocks): + diffusers_transformer_block_prefix = f"transformer_blocks.{transformer_block_idx}" + transformer_block_prefix = f"{transformer_prefix}.blocks.{transformer_block_idx}" + + # ada norm block + diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm1" + ada_norm_prefix = f"{transformer_block_prefix}.ln1" + + diffusers_checkpoint.update( + transformer_ada_norm_to_diffusers_checkpoint( + checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix + ) + ) + + # attention block + diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn1" + attention_prefix = f"{transformer_block_prefix}.attn1" + + diffusers_checkpoint.update( + transformer_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # ada norm block + diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm2" + ada_norm_prefix = f"{transformer_block_prefix}.ln1_1" + + diffusers_checkpoint.update( + transformer_ada_norm_to_diffusers_checkpoint( + checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix + ) + ) + + # attention block + diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn2" + attention_prefix = f"{transformer_block_prefix}.attn2" + + diffusers_checkpoint.update( + transformer_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # norm block + diffusers_norm_block_prefix = f"{diffusers_transformer_block_prefix}.norm3" + norm_block_prefix = f"{transformer_block_prefix}.ln2" + + diffusers_checkpoint.update( + { + f"{diffusers_norm_block_prefix}.weight": checkpoint[f"{norm_block_prefix}.weight"], + f"{diffusers_norm_block_prefix}.bias": checkpoint[f"{norm_block_prefix}.bias"], + } + ) + + # feedforward block + diffusers_feedforward_prefix = f"{diffusers_transformer_block_prefix}.ff" + feedforward_prefix = f"{transformer_block_prefix}.mlp" + + diffusers_checkpoint.update( + transformer_feedforward_to_diffusers_checkpoint( + checkpoint, + diffusers_feedforward_prefix=diffusers_feedforward_prefix, + feedforward_prefix=feedforward_prefix, + ) + ) + + # to logits + + diffusers_norm_out_prefix = "norm_out" + norm_out_prefix = f"{transformer_prefix}.to_logits.0" + + diffusers_checkpoint.update( + { + f"{diffusers_norm_out_prefix}.weight": checkpoint[f"{norm_out_prefix}.weight"], + f"{diffusers_norm_out_prefix}.bias": checkpoint[f"{norm_out_prefix}.bias"], + } + ) + + diffusers_out_prefix = "out" + out_prefix = f"{transformer_prefix}.to_logits.1" + + diffusers_checkpoint.update( + { + f"{diffusers_out_prefix}.weight": checkpoint[f"{out_prefix}.weight"], + f"{diffusers_out_prefix}.bias": checkpoint[f"{out_prefix}.bias"], + } + ) + + return diffusers_checkpoint + + +def transformer_ada_norm_to_diffusers_checkpoint(checkpoint, *, diffusers_ada_norm_prefix, ada_norm_prefix): + return { + f"{diffusers_ada_norm_prefix}.emb.weight": checkpoint[f"{ada_norm_prefix}.emb.weight"], + f"{diffusers_ada_norm_prefix}.linear.weight": checkpoint[f"{ada_norm_prefix}.linear.weight"], + f"{diffusers_ada_norm_prefix}.linear.bias": checkpoint[f"{ada_norm_prefix}.linear.bias"], + } + + +def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): + return { + # key + f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.key.weight"], + f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.key.bias"], + # query + f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.query.weight"], + f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.query.bias"], + # value + f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.value.weight"], + f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.value.bias"], + # linear out + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj.weight"], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj.bias"], + } + + +def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix): + return { + f"{diffusers_feedforward_prefix}.net.0.proj.weight": checkpoint[f"{feedforward_prefix}.0.weight"], + f"{diffusers_feedforward_prefix}.net.0.proj.bias": checkpoint[f"{feedforward_prefix}.0.bias"], + f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"], + f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"], + } + + +# done transformer checkpoint + + +def read_config_file(filename): + # The yaml file contains annotations that certain values should + # loaded as tuples. By default, OmegaConf will panic when reading + # these. Instead, we can manually read the yaml with the FullLoader and then + # construct the OmegaConf object. + with open(filename) as f: + original_config = yaml.load(f, FullLoader) + + return OmegaConf.create(original_config) + + +# We take separate arguments for the vqvae because the ITHQ vqvae config file +# is separate from the config file for the rest of the model. +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--vqvae_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the vqvae checkpoint to convert.", + ) + + parser.add_argument( + "--vqvae_original_config_file", + default=None, + type=str, + required=True, + help="The YAML config file corresponding to the original architecture for the vqvae.", + ) + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + + parser.add_argument( + "--original_config_file", + default=None, + type=str, + required=True, + help="The YAML config file corresponding to the original architecture.", + ) + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--checkpoint_load_device", + default="cpu", + type=str, + required=False, + help="The device passed to `map_location` when loading checkpoints.", + ) + + # See link for how ema weights are always selected + # https://github.com/microsoft/VQ-Diffusion/blob/3c98e77f721db7c787b76304fa2c96a36c7b00af/inference_VQ_Diffusion.py#L65 + parser.add_argument( + "--no_use_ema", + action="store_true", + required=False, + help=( + "Set to not use the ema weights from the original VQ-Diffusion checkpoint. You probably do not want to set" + " it as the original VQ-Diffusion always uses the ema weights when loading models." + ), + ) + + args = parser.parse_args() + + use_ema = not args.no_use_ema + + print(f"loading checkpoints to {args.checkpoint_load_device}") + + checkpoint_map_location = torch.device(args.checkpoint_load_device) + + # vqvae_model + + print(f"loading vqvae, config: {args.vqvae_original_config_file}, checkpoint: {args.vqvae_checkpoint_path}") + + vqvae_original_config = read_config_file(args.vqvae_original_config_file).model + vqvae_checkpoint = torch.load(args.vqvae_checkpoint_path, map_location=checkpoint_map_location)["model"] + + with init_empty_weights(): + vqvae_model = vqvae_model_from_original_config(vqvae_original_config) + + vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, vqvae_checkpoint) + + with tempfile.NamedTemporaryFile() as vqvae_diffusers_checkpoint_file: + torch.save(vqvae_diffusers_checkpoint, vqvae_diffusers_checkpoint_file.name) + del vqvae_diffusers_checkpoint + del vqvae_checkpoint + load_checkpoint_and_dispatch(vqvae_model, vqvae_diffusers_checkpoint_file.name, device_map="auto") + + print("done loading vqvae") + + # done vqvae_model + + # transformer_model + + print( + f"loading transformer, config: {args.original_config_file}, checkpoint: {args.checkpoint_path}, use ema:" + f" {use_ema}" + ) + + original_config = read_config_file(args.original_config_file).model + + diffusion_config = original_config.params.diffusion_config + transformer_config = original_config.params.diffusion_config.params.transformer_config + content_embedding_config = original_config.params.diffusion_config.params.content_emb_config + + pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location) + + if use_ema: + if "ema" in pre_checkpoint: + checkpoint = {} + for k, v in pre_checkpoint["model"].items(): + checkpoint[k] = v + + for k, v in pre_checkpoint["ema"].items(): + # The ema weights are only used on the transformer. To mimic their key as if they came + # from the state_dict for the top level model, we prefix with an additional "transformer." + # See the source linked in the args.use_ema config for more information. + checkpoint[f"transformer.{k}"] = v + else: + print("attempted to load ema weights but no ema weights are specified in the loaded checkpoint.") + checkpoint = pre_checkpoint["model"] + else: + checkpoint = pre_checkpoint["model"] + + del pre_checkpoint + + with init_empty_weights(): + transformer_model = transformer_model_from_original_config( + diffusion_config, transformer_config, content_embedding_config + ) + + diffusers_transformer_checkpoint = transformer_original_checkpoint_to_diffusers_checkpoint( + transformer_model, checkpoint + ) + + # classifier free sampling embeddings interlude + + # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate + # model, so we pull them off the checkpoint before the checkpoint is deleted. + + learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf + + if learnable_classifier_free_sampling_embeddings: + learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"] + else: + learned_classifier_free_sampling_embeddings_embeddings = None + + # done classifier free sampling embeddings interlude + + with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file: + torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name) + del diffusers_transformer_checkpoint + del checkpoint + load_checkpoint_and_dispatch(transformer_model, diffusers_transformer_checkpoint_file.name, device_map="auto") + + print("done loading transformer") + + # done transformer_model + + # text encoder + + print("loading CLIP text encoder") + + clip_name = "openai/clip-vit-base-patch32" + + # The original VQ-Diffusion specifies the pad value by the int used in the + # returned tokens. Each model uses `0` as the pad value. The transformers clip api + # specifies the pad value via the token before it has been tokenized. The `!` pad + # token is the same as padding with the `0` pad value. + pad_token = "!" + + tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto") + + assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0 + + text_encoder_model = CLIPTextModel.from_pretrained( + clip_name, + # `CLIPTextModel` does not support device_map="auto" + # device_map="auto" + ) + + print("done loading CLIP text encoder") + + # done text encoder + + # scheduler + + scheduler_model = VQDiffusionScheduler( + # the scheduler has the same number of embeddings as the transformer + num_vec_classes=transformer_model.num_vector_embeds + ) + + # done scheduler + + # learned classifier free sampling embeddings + + with init_empty_weights(): + learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings( + learnable_classifier_free_sampling_embeddings, + hidden_size=text_encoder_model.config.hidden_size, + length=tokenizer_model.model_max_length, + ) + + learned_classifier_free_sampling_checkpoint = { + "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float() + } + + with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file: + torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name) + del learned_classifier_free_sampling_checkpoint + del learned_classifier_free_sampling_embeddings_embeddings + load_checkpoint_and_dispatch( + learned_classifier_free_sampling_embeddings_model, + learned_classifier_free_sampling_checkpoint_file.name, + device_map="auto", + ) + + # done learned classifier free sampling embeddings + + print(f"saving VQ diffusion model, path: {args.dump_path}") + + pipe = VQDiffusionPipeline( + vqvae=vqvae_model, + transformer=transformer_model, + tokenizer=tokenizer_model, + text_encoder=text_encoder_model, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model, + scheduler=scheduler_model, + ) + pipe.save_pretrained(args.dump_path) + + print("done writing VQ diffusion model") diff --git a/setup.py b/setup.py index 25e434de2801..9148acce2613 100644 --- a/setup.py +++ b/setup.py @@ -78,7 +78,7 @@ # 1. all dependencies should be listed here with their version requirements if any # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py _deps = [ - "Pillow<10.0", # keep the PIL.Image.Resampling deprecation away + "Pillow", # keep the PIL.Image.Resampling deprecation away "accelerate>=0.11.0", "black==22.8", "datasets", @@ -89,14 +89,15 @@ "huggingface-hub>=0.10.0", "importlib_metadata", "isort>=5.5.4", - "jax>=0.2.8,!=0.3.2,<=0.3.6", - "jaxlib>=0.1.65,<=0.3.6", + "jax>=0.2.8,!=0.3.2", + "jaxlib>=0.1.65", "modelcards>=0.1.4", "numpy", - "onnxruntime", + "parameterized", "pytest", "pytest-timeout", "pytest-xdist", + "sentencepiece>=0.1.91,!=0.1.92", "scipy", "regex!=2019.12.17", "requests", @@ -178,17 +179,17 @@ def run(self): extras["docs"] = deps_list("hf-doc-builder") extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards") extras["test"] = deps_list( - "accelerate", "datasets", - "onnxruntime", + "parameterized", "pytest", "pytest-timeout", "pytest-xdist", + "sentencepiece", "scipy", "torchvision", "transformers" ) -extras["torch"] = deps_list("torch") +extras["torch"] = deps_list("torch", "accelerate") if os.name == "nt": # windows extras["flax"] = [] # jax is not supported on windows @@ -211,7 +212,7 @@ def run(self): setup( name="diffusers", - version="0.6.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="0.9.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description="Diffusers", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0a0b0b4965dd..256eb8fee8bc 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -9,7 +9,7 @@ ) -__version__ = "0.6.0.dev0" +__version__ = "0.9.0" from .configuration_utils import ConfigMixin from .onnx_utils import OnnxRuntimeModel @@ -18,7 +18,7 @@ if is_torch_available(): from .modeling_utils import ModelMixin - from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel + from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, @@ -29,14 +29,30 @@ get_scheduler, ) from .pipeline_utils import DiffusionPipeline - from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline + from .pipelines import ( + DanceDiffusionPipeline, + DDIMPipeline, + DDPMPipeline, + KarrasVePipeline, + LDMPipeline, + LDMSuperResolutionPipeline, + PNDMPipeline, + RePaintPipeline, + ScoreSdeVePipeline, + ) from .schedulers import ( DDIMScheduler, DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + IPNDMScheduler, KarrasVeScheduler, PNDMScheduler, + RePaintScheduler, SchedulerMixin, ScoreSdeVeScheduler, + VQDiffusionScheduler, ) from .training_utils import EMAModel else: @@ -49,16 +65,34 @@ if is_torch_available() and is_transformers_available(): from .pipelines import ( + AltDiffusionImg2ImgPipeline, + AltDiffusionPipeline, + CycleDiffusionPipeline, LDMTextToImagePipeline, + StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, + StableDiffusionInpaintPipelineLegacy, StableDiffusionPipeline, + StableDiffusionPipelineSafe, + StableDiffusionUpscalePipeline, + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + VQDiffusionPipeline, ) else: from .utils.dummy_torch_and_transformers_objects import * # noqa F403 if is_torch_available() and is_transformers_available() and is_onnx_available(): - from .pipelines import StableDiffusionOnnxPipeline + from .pipelines import ( + OnnxStableDiffusionImg2ImgPipeline, + OnnxStableDiffusionInpaintPipeline, + OnnxStableDiffusionInpaintPipelineLegacy, + OnnxStableDiffusionPipeline, + StableDiffusionOnnxPipeline, + ) else: from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 @@ -70,6 +104,7 @@ from .schedulers import ( FlaxDDIMScheduler, FlaxDDPMScheduler, + FlaxDPMSolverMultistepScheduler, FlaxKarrasVeScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 6f866c6b5fff..f06586b23698 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -13,9 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" ConfigMixinuration base class and utilities.""" +""" ConfigMixin base class and utilities.""" import dataclasses import functools +import importlib import inspect import json import os @@ -28,7 +29,7 @@ from requests import HTTPError from . import __version__ -from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging +from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging logger = logging.get_logger(__name__) @@ -36,6 +37,38 @@ _re_configuration_file = re.compile(r"config\.(.*)\.json") +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __setattr__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setattr__(name, value) + + def __setitem__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setitem__(name, value) + + class ConfigMixin: r""" Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all @@ -47,17 +80,21 @@ class ConfigMixin: - **config_name** (`str`) -- A filename under which the config should stored when calling [`~ConfigMixin.save_config`] (should be overridden by parent class). - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be - overridden by parent class). + overridden by subclass). + - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass). + - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function + should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by + subclass). """ config_name = None ignore_for_config = [] + has_compatibles = False + + _deprecated_kwargs = [] def register_to_config(self, **kwargs): if self.config_name is None: raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") - kwargs["_class_name"] = self.__class__.__name__ - kwargs["_diffusers_version"] = __version__ - # Special case for `kwargs` used in deprecation warning added to schedulers # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, # or solve in a more general way. @@ -96,12 +133,106 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool output_config_file = os.path.join(save_directory, self.config_name) self.to_json_file(output_config_file) - logger.info(f"ConfigMixinuration saved in {output_config_file}") + logger.info(f"Configuration saved in {output_config_file}") @classmethod - def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): + def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): r""" - Instantiate a Python class from a pre-defined JSON-file. + Instantiate a Python class from a config dictionary + + Parameters: + config (`Dict[str, Any]`): + A config dictionary from which the Python class will be instantiated. Make sure to only load + configuration files of compatible classes. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the Python class. + `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually + overwrite same named arguments of `config`. + + Examples: + + ```python + >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler + + >>> # Download scheduler from huggingface.co and cache. + >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32") + + >>> # Instantiate DDIM scheduler class with same config as DDPM + >>> scheduler = DDIMScheduler.from_config(scheduler.config) + + >>> # Instantiate PNDM scheduler class with same config as DDPM + >>> scheduler = PNDMScheduler.from_config(scheduler.config) + ``` + """ + # <===== TO BE REMOVED WITH DEPRECATION + # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated + if "pretrained_model_name_or_path" in kwargs: + config = kwargs.pop("pretrained_model_name_or_path") + + if config is None: + raise ValueError("Please make sure to provide a config as the first positional argument.") + # ======> + + if not isinstance(config, dict): + deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`." + if "Scheduler" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead." + " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will" + " be removed in v1.0.0." + ) + elif "Model" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a model, please use {cls}.load_config(...) followed by" + f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary" + " instead. This functionality will be removed in v1.0.0." + ) + deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False) + config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs) + + init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) + + # Allow dtype to be specified on initialization + if "dtype" in unused_kwargs: + init_dict["dtype"] = unused_kwargs.pop("dtype") + + # add possible deprecated kwargs + for deprecated_kwarg in cls._deprecated_kwargs: + if deprecated_kwarg in unused_kwargs: + init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg) + + # Return model and optionally state and/or unused_kwargs + model = cls(**init_dict) + + # make sure to also save config parameters that might be used for compatible classes + model.register_to_config(**hidden_dict) + + # add hidden kwargs of compatible classes to unused_kwargs + unused_kwargs = {**unused_kwargs, **hidden_dict} + + if return_unused_kwargs: + return (model, unused_kwargs) + else: + return model + + @classmethod + def get_config_dict(cls, *args, **kwargs): + deprecation_message = ( + f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be" + " removed in version v1.0.0" + ) + deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False) + return cls.load_config(*args, **kwargs) + + @classmethod + def load_config( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + r""" + Instantiate a Python class from a config dictionary Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): @@ -115,10 +246,6 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. - ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): - Whether or not to raise an error if some of the weights from the checkpoint do not have the same size - as the weights of the model (if for instance, you are instantiating a model with 10 labels from a - checkpoint with 3 labels). force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. @@ -156,33 +283,7 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret use this method in a firewalled environment. - """ - config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) - init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) - - # Allow dtype to be specified on initialization - if "dtype" in unused_kwargs: - init_dict["dtype"] = unused_kwargs.pop("dtype") - - # Return model and optionally state and/or unused_kwargs - model = cls(**init_dict) - return_tuple = (model,) - - # Flax schedulers have a state, so return it. - if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False): - state = model.create_state() - return_tuple += (state,) - - if return_unused_kwargs: - return return_tuple + (unused_kwargs,) - else: - return return_tuple if len(return_tuple) > 1 else model - - @classmethod - def get_config_dict( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) @@ -278,11 +379,22 @@ def get_config_dict( except (json.JSONDecodeError, UnicodeDecodeError): raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") + if return_unused_kwargs: + return config_dict, kwargs + return config_dict + @staticmethod + def _get_init_keys(cls): + return set(dict(inspect.signature(cls.__init__).parameters).keys()) + @classmethod def extract_init_dict(cls, config_dict, **kwargs): - expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys()) + # 0. Copy origin config dict + original_dict = {k: v for k, v in config_dict.items()} + + # 1. Retrieve expected config attributes from __init__ signature + expected_keys = cls._get_init_keys(cls) expected_keys.remove("self") # remove general kwargs if present in dict if "kwargs" in expected_keys: @@ -292,11 +404,44 @@ def extract_init_dict(cls, config_dict, **kwargs): for arg in cls._flax_internal_args: expected_keys.remove(arg) + # 2. Remove attributes that cannot be expected from expected config attributes # remove keys to be ignored if len(cls.ignore_for_config) > 0: expected_keys = expected_keys - set(cls.ignore_for_config) + + # load diffusers library to import compatible and original scheduler + diffusers_library = importlib.import_module(__name__.split(".")[0]) + + if cls.has_compatibles: + compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)] + else: + compatible_classes = [] + + expected_keys_comp_cls = set() + for c in compatible_classes: + expected_keys_c = cls._get_init_keys(c) + expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c) + expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls) + config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls} + + # remove attributes from orig class that cannot be expected + orig_cls_name = config_dict.pop("_class_name", cls.__name__) + if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name): + orig_cls = getattr(diffusers_library, orig_cls_name) + unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys + config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig} + + # remove private attributes + config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} + + # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments init_dict = {} for key in expected_keys: + # if config param is passed to kwarg and is present in config dict + # it should overwrite existing config dict key + if key in kwargs and key in config_dict: + config_dict[key] = kwargs.pop(key) + if key in kwargs: # overwrite key init_dict[key] = kwargs.pop(key) @@ -304,8 +449,7 @@ def extract_init_dict(cls, config_dict, **kwargs): # use value from config dict init_dict[key] = config_dict.pop(key) - config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} - + # 4. Give nice warning if unexpected values have been passed if len(config_dict) > 0: logger.warning( f"The config attributes {config_dict} were passed to {cls.__name__}, " @@ -313,15 +457,20 @@ def extract_init_dict(cls, config_dict, **kwargs): f"{cls.config_name} configuration file." ) - unused_kwargs = {**config_dict, **kwargs} - + # 5. Give nice info if config attributes are initiliazed to default because they have not been passed passed_keys = set(init_dict.keys()) if len(expected_keys - passed_keys) > 0: logger.info( f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." ) - return init_dict, unused_kwargs + # 6. Define unused keyword arguments + unused_kwargs = {**config_dict, **kwargs} + + # 7. Define "hidden" config parameters that were saved for compatible classes + hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} + + return init_dict, unused_kwargs, hidden_config_dict @classmethod def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): @@ -334,6 +483,12 @@ def __repr__(self): @property def config(self) -> Dict[str, Any]: + """ + Returns the config of the class as a frozen dictionary + + Returns: + `Dict[str, Any]`: Config of the class. + """ return self._internal_dict def to_json_string(self) -> str: @@ -344,6 +499,9 @@ def to_json_string(self) -> str: `str`: String containing all the attributes that make up this configuration instance in JSON format. """ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} + config_dict["_class_name"] = self.__class__.__name__ + config_dict["_diffusers_version"] = __version__ + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" def to_json_file(self, json_file_path: Union[str, os.PathLike]): @@ -358,38 +516,6 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike]): writer.write(self.to_json_string()) -class FrozenDict(OrderedDict): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - for key, value in self.items(): - setattr(self, key, value) - - self.__frozen = True - - def __delitem__(self, *args, **kwargs): - raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") - - def setdefault(self, *args, **kwargs): - raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") - - def pop(self, *args, **kwargs): - raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") - - def update(self, *args, **kwargs): - raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") - - def __setattr__(self, name, value): - if hasattr(self, "__frozen") and self.__frozen: - raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") - super().__setattr__(name, value) - - def __setitem__(self, name, value): - if hasattr(self, "__frozen") and self.__frozen: - raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") - super().__setitem__(name, value) - - def register_to_config(init): r""" Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are @@ -403,7 +529,7 @@ def register_to_config(init): def inner_init(self, *args, **kwargs): # Ignore private kwargs in the init. init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} - init(self, *args, **init_kwargs) + config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} if not isinstance(self, ConfigMixin): raise RuntimeError( f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " @@ -428,7 +554,9 @@ def inner_init(self, *args, **kwargs): if k not in ignore and k not in new_kwargs } ) + new_kwargs = {**config_init_kwargs, **new_kwargs} getattr(self, "register_to_config")(**new_kwargs) + init(self, *args, **init_kwargs) return inner_init @@ -445,7 +573,7 @@ def init(self, *args, **kwargs): ) # Ignore private kwargs in the init. Retrieve all passed attributes - init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + init_kwargs = {k: v for k, v in kwargs.items()} # Retrieve default values fields = dataclasses.fields(self) diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 8b10d70a26f7..d187b7914534 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -2,7 +2,7 @@ # 1. modify the `_deps` dict in setup.py # 2. run `make deps_table_update`` deps = { - "Pillow": "Pillow<10.0", + "Pillow": "Pillow", "accelerate": "accelerate>=0.11.0", "black": "black==22.8", "datasets": "datasets", @@ -13,14 +13,15 @@ "huggingface-hub": "huggingface-hub>=0.10.0", "importlib_metadata": "importlib_metadata", "isort": "isort>=5.5.4", - "jax": "jax>=0.2.8,!=0.3.2,<=0.3.6", - "jaxlib": "jaxlib>=0.1.65,<=0.3.6", + "jax": "jax>=0.2.8,!=0.3.2", + "jaxlib": "jaxlib>=0.1.65", "modelcards": "modelcards>=0.1.4", "numpy": "numpy", - "onnxruntime": "onnxruntime", + "parameterized": "parameterized", "pytest": "pytest", "pytest-timeout": "pytest-timeout", "pytest-xdist": "pytest-xdist", + "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "scipy": "scipy", "regex": "regex!=2019.12.17", "requests": "requests", diff --git a/src/diffusers/experimental/README.md b/src/diffusers/experimental/README.md new file mode 100644 index 000000000000..81a9de81c737 --- /dev/null +++ b/src/diffusers/experimental/README.md @@ -0,0 +1,5 @@ +# 🧨 Diffusers Experimental + +We are adding experimental code to support novel applications and usages of the Diffusers library. +Currently, the following experiments are supported: +* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model. \ No newline at end of file diff --git a/src/diffusers/experimental/__init__.py b/src/diffusers/experimental/__init__.py new file mode 100644 index 000000000000..ebc815540301 --- /dev/null +++ b/src/diffusers/experimental/__init__.py @@ -0,0 +1 @@ +from .rl import ValueGuidedRLPipeline diff --git a/src/diffusers/experimental/rl/__init__.py b/src/diffusers/experimental/rl/__init__.py new file mode 100644 index 000000000000..7b338d3173e1 --- /dev/null +++ b/src/diffusers/experimental/rl/__init__.py @@ -0,0 +1 @@ +from .value_guided_sampling import ValueGuidedRLPipeline diff --git a/src/diffusers/experimental/rl/value_guided_sampling.py b/src/diffusers/experimental/rl/value_guided_sampling.py new file mode 100644 index 000000000000..4dd935f54d60 --- /dev/null +++ b/src/diffusers/experimental/rl/value_guided_sampling.py @@ -0,0 +1,130 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch + +import tqdm + +from ...models.unet_1d import UNet1DModel +from ...pipeline_utils import DiffusionPipeline +from ...utils.dummy_pt_objects import DDPMScheduler + + +class ValueGuidedRLPipeline(DiffusionPipeline): + def __init__( + self, + value_function: UNet1DModel, + unet: UNet1DModel, + scheduler: DDPMScheduler, + env, + ): + super().__init__() + self.value_function = value_function + self.unet = unet + self.scheduler = scheduler + self.env = env + self.data = env.get_dataset() + self.means = dict() + for key in self.data.keys(): + try: + self.means[key] = self.data[key].mean() + except: + pass + self.stds = dict() + for key in self.data.keys(): + try: + self.stds[key] = self.data[key].std() + except: + pass + self.state_dim = env.observation_space.shape[0] + self.action_dim = env.action_space.shape[0] + + def normalize(self, x_in, key): + return (x_in - self.means[key]) / self.stds[key] + + def de_normalize(self, x_in, key): + return x_in * self.stds[key] + self.means[key] + + def to_torch(self, x_in): + if type(x_in) is dict: + return {k: self.to_torch(v) for k, v in x_in.items()} + elif torch.is_tensor(x_in): + return x_in.to(self.unet.device) + return torch.tensor(x_in, device=self.unet.device) + + def reset_x0(self, x_in, cond, act_dim): + for key, val in cond.items(): + x_in[:, key, act_dim:] = val.clone() + return x_in + + def run_diffusion(self, x, conditions, n_guide_steps, scale): + batch_size = x.shape[0] + y = None + for i in tqdm.tqdm(self.scheduler.timesteps): + # create batch of timesteps to pass into model + timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) + for _ in range(n_guide_steps): + with torch.enable_grad(): + x.requires_grad_() + y = self.value_function(x.permute(0, 2, 1), timesteps).sample + grad = torch.autograd.grad([y.sum()], [x])[0] + + posterior_variance = self.scheduler._get_variance(i) + model_std = torch.exp(0.5 * posterior_variance) + grad = model_std * grad + grad[timesteps < 2] = 0 + x = x.detach() + x = x + scale * grad + x = self.reset_x0(x, conditions, self.action_dim) + prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) + # TODO: set prediction_type when instantiating the model + x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] + + # apply conditions to the trajectory + x = self.reset_x0(x, conditions, self.action_dim) + x = self.to_torch(x) + return x, y + + def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): + # normalize the observations and create batch dimension + obs = self.normalize(obs, "observations") + obs = obs[None].repeat(batch_size, axis=0) + + conditions = {0: self.to_torch(obs)} + shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) + + # generate initial noise and apply our conditions (to make the trajectories start at current state) + x1 = torch.randn(shape, device=self.unet.device) + x = self.reset_x0(x1, conditions, self.action_dim) + x = self.to_torch(x) + + # run the diffusion process + x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) + + # sort output trajectories by value + sorted_idx = y.argsort(0, descending=True).squeeze() + sorted_values = x[sorted_idx] + actions = sorted_values[:, :, : self.action_dim] + actions = actions.detach().cpu().numpy() + denorm_actions = self.de_normalize(actions, key="actions") + + # select the action with the highest value + if y is not None: + selected_index = 0 + else: + # if we didn't run value guiding, select a random action + selected_index = np.random.randint(0, batch_size) + denorm_actions = denorm_actions[selected_index, 0] + return denorm_actions diff --git a/src/diffusers/hub_utils.py b/src/diffusers/hub_utils.py index c07329e36fe7..8bf0933a1dce 100644 --- a/src/diffusers/hub_utils.py +++ b/src/diffusers/hub_utils.py @@ -16,13 +16,25 @@ import os import shutil +import sys from pathlib import Path -from typing import Optional +from typing import Dict, Optional, Union +from uuid import uuid4 from huggingface_hub import HfFolder, Repository, whoami -from .pipeline_utils import DiffusionPipeline -from .utils import is_modelcards_available, logging +from . import __version__ +from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging +from .utils.import_utils import ( + _flax_version, + _jax_version, + _onnxruntime_version, + _torch_version, + is_flax_available, + is_modelcards_available, + is_onnx_available, + is_torch_available, +) if is_modelcards_available(): @@ -33,6 +45,32 @@ MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md" +SESSION_ID = uuid4().hex +DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES + + +def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: + """ + Formats a user-agent string with basic info about a request. + """ + ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}" + if DISABLE_TELEMETRY: + return ua + "; telemetry/off" + if is_torch_available(): + ua += f"; torch/{_torch_version}" + if is_flax_available(): + ua += f"; jax/{_jax_version}" + ua += f"; flax/{_flax_version}" + if is_onnx_available(): + ua += f"; onnxruntime/{_onnxruntime_version}" + # CI will set this value to True + if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES: + ua += "; is_ci/true" + if isinstance(user_agent, dict): + ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) + elif isinstance(user_agent, str): + ua += "; " + user_agent + return ua def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): @@ -53,6 +91,12 @@ def init_git_repo(args, at_init: bool = False): Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out. """ + deprecation_message = ( + "Please use `huggingface_hub.Repository`. " + "See `examples/unconditional_image_generation/train_unconditional.py` for an example." + ) + deprecate("init_git_repo()", "0.10.0", deprecation_message) + if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: return hub_token = args.hub_token if hasattr(args, "hub_token") else None @@ -95,7 +139,7 @@ def init_git_repo(args, at_init: bool = False): def push_to_hub( args, - pipeline: DiffusionPipeline, + pipeline, repo: Repository, commit_message: Optional[str] = "End of training", blocking: bool = True, @@ -114,6 +158,11 @@ def push_to_hub( The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the commit and an object to track the progress of the commit if `blocking=True` """ + deprecation_message = ( + "Please use `huggingface_hub.Repository` and `Repository.push_to_hub()`. " + "See `examples/unconditional_image_generation/train_unconditional.py` for an example." + ) + deprecate("push_to_hub()", "0.10.0", deprecation_message) if not hasattr(args, "hub_model_id") or args.hub_model_id is None: model_name = Path(args.output_dir).name diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 6cb30a26f7d5..5ef100224907 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -105,14 +105,14 @@ def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): >>> from diffusers import FlaxUNet2DConditionModel >>> # load model - >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision >>> params = model.to_bf16(params) >>> # If you don't want to cast certain parameters (for example layer norm bias and scale) >>> # then pass the mask as follows >>> from flax import traverse_util - >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> flat_params = traverse_util.flatten_dict(params) >>> mask = { ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) @@ -141,7 +141,7 @@ def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): >>> from diffusers import FlaxUNet2DConditionModel >>> # Download model and configuration from huggingface.co - >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> # By default, the model params will be in fp32, to illustrate the use of this method, >>> # we'll first cast to fp16 and back to fp32 >>> params = model.to_f16(params) @@ -171,14 +171,14 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): >>> from diffusers import FlaxUNet2DConditionModel >>> # load model - >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> # By default, the model params will be in fp32, to cast these to float16 >>> params = model.to_fp16(params) >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) >>> # then pass the mask as follows >>> from flax import traverse_util - >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> flat_params = traverse_util.flatten_dict(params) >>> mask = { ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) @@ -216,7 +216,7 @@ def from_pretrained( - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids are namespaced under a user or organization name, like - `CompVis/stable-diffusion-v1-4`. + `runwayml/stable-diffusion-v1-5`. - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], e.g., `./my_model_directory/`. dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): @@ -273,7 +273,7 @@ def from_pretrained( >>> from diffusers import FlaxUNet2DConditionModel >>> # Download model and configuration from huggingface.co and cache. - >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/") ```""" @@ -482,29 +482,6 @@ def from_pretrained( " training." ) - # dictionary of key: dtypes for the model params - param_dtypes = jax.tree_map(lambda x: x.dtype, state) - # extract keys of parameters not in jnp.float32 - fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16] - bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16] - - # raise a warning if any of the parameters are not in jnp.float32 - if len(fp16_params) > 0: - logger.warning( - f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from " - f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n" - "You should probably UPCAST the model weights to float32 if this was not intended. " - "See [`~ModelMixin.to_fp32`] for further information on how to do this." - ) - - if len(bf16_params) > 0: - logger.warning( - f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from " - f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n" - "You should probably UPCAST the model weights to float32 if this was not intended. " - "See [`~ModelMixin.to_fp32`] for further information on how to do this." - ) - return model, unflatten_dict(state) def save_pretrained( diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index c62caf80286e..8cb0acf52f42 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -21,18 +21,37 @@ import torch from torch import Tensor, device -from diffusers.utils import is_accelerate_available from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError from . import __version__ -from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging +from .utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + WEIGHTS_NAME, + is_accelerate_available, + is_torch_version, + logging, +) logger = logging.get_logger(__name__) +if is_torch_version(">=", "1.9.0"): + _LOW_CPU_MEM_USAGE_DEFAULT = True +else: + _LOW_CPU_MEM_USAGE_DEFAULT = False + + +if is_accelerate_available(): + import accelerate + from accelerate.utils import set_module_tensor_to_device + from accelerate.utils.versions import is_torch_version + + def get_parameter_device(parameter: torch.nn.Module): try: return next(parameter.parameters()).device @@ -268,6 +287,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please refer to the mirror site for more information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. @@ -296,6 +328,41 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P torch_dtype = kwargs.pop("torch_dtype", None) subfolder = kwargs.pop("subfolder", None) device_map = kwargs.pop("device_map", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) user_agent = { "diffusers": __version__, @@ -378,14 +445,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # restore default dtype - if device_map == "auto": - if is_accelerate_available(): - import accelerate - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - + if low_cpu_mem_usage: + # Instantiate model with empty weights with accelerate.init_empty_weights(): - model, unused_kwargs = cls.from_config( + config, unused_kwargs = cls.load_config( config_path, cache_dir=cache_dir, return_unused_kwargs=True, @@ -399,8 +462,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P device_map=device_map, **kwargs, ) - - accelerate.load_checkpoint_and_dispatch(model, model_file, device_map) + model = cls.from_config(config, **unused_kwargs) + + # if device_map is Non,e load the state dict on move the params from meta device to the cpu + if device_map is None: + param_device = "cpu" + state_dict = load_state_dict(model_file) + # move the parms from meta device to cpu + for param_name, param in state_dict.items(): + set_module_tensor_to_device(model, param_name, param_device, value=param) + else: # else let accelerate handle loading and dispatching. + # Load weights and dispatch according to the device_map + # by deafult the device_map is None and the weights are loaded on the CPU + accelerate.load_checkpoint_and_dispatch(model, model_file, device_map) loading_info = { "missing_keys": [], @@ -409,7 +483,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P "error_msgs": [], } else: - model, unused_kwargs = cls.from_config( + config, unused_kwargs = cls.load_config( config_path, cache_dir=cache_dir, return_unused_kwargs=True, @@ -423,6 +497,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P device_map=device_map, **kwargs, ) + model = cls.from_config(config, **unused_kwargs) state_dict = load_state_dict(model_file) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( diff --git a/src/diffusers/models/README.md b/src/diffusers/models/README.md index e786fe518f03..80fe0bc38140 100644 --- a/src/diffusers/models/README.md +++ b/src/diffusers/models/README.md @@ -1,11 +1,3 @@ # Models -- Models: Neural network that models $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ (see image below) and is trained end-to-end to denoise a noisy input to an image. Examples: UNet, Conditioned UNet, 3D UNet, Transformer UNet - -## API - -TODO(Suraj, Patrick) - -## Examples - -TODO(Suraj, Patrick) +For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models). \ No newline at end of file diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 1242ad6fca7f..5b101d169148 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,6 +16,8 @@ if is_torch_available(): + from .attention import Transformer2DModel + from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel from .vae import AutoencoderKL, VQModel diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 41830ed3917c..2bbc1a9c41de 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -12,12 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +import warnings +from dataclasses import dataclass from typing import Optional import torch import torch.nn.functional as F from torch import nn +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..models.embeddings import ImagePositionalEmbeddings +from ..utils import BaseOutput +from ..utils.import_utils import is_xformers_available + from einops import rearrange # TODO(danfu): upstream this to FlashAttention @@ -41,6 +49,224 @@ def check_cuda(): except ImportError: flash_attn_installed = False +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions + for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual + embeddings) inputs. + When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard + transformer action. Finally, reshape to image. + When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional + embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict + classes of unnoised image. + Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised + image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of context dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = in_channels is not None + self.is_input_vectorized = num_vector_embeds is not None + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized: + raise ValueError( + f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if self.is_input_continuous: + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + + def _set_attention_slice(self, slice_size): + for block in self.transformer_blocks: + block._set_attention_slice(slice_size) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + Returns: + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. + """ + # 1. Input + if self.is_input_continuous: + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for block in self.transformer_blocks: + block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + class AttentionBlock(nn.Module): """ @@ -50,19 +276,19 @@ class AttentionBlock(nn.Module): Uses three q, k, v linear layers to compute attention. Parameters: - channels (:obj:`int`): The number of channels in the input and output. - num_head_channels (:obj:`int`, *optional*): + channels (`int`): The number of channels in the input and output. + num_head_channels (`int`, *optional*): The number of channels in each head. If None, then `num_heads` = 1. - num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. - rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. - eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. + rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. """ def __init__( self, channels: int, num_head_channels: Optional[int] = None, - num_groups: int = 32, + norm_num_groups: int = 32, rescale_output_factor: float = 1.0, eps: float = 1e-5, ): @@ -71,7 +297,7 @@ def __init__( self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 self.num_head_size = num_head_channels - self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) # define q,k,v as linear layers self.query = nn.Linear(channels, channels) @@ -101,21 +327,51 @@ def forward(self, hidden_states): key_proj = self.key(hidden_states) value_proj = self.value(hidden_states) - # transpose - query_states = self.transpose_for_scores(query_proj) - key_states = self.transpose_for_scores(key_proj) - value_states = self.transpose_for_scores(value_proj) + scale = 1 / math.sqrt(self.channels / self.num_heads) # get scores - scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) - attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm + if self.num_heads > 1: + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors? + # or reformulate this into a 3D problem? + # TODO: measure whether on MPS device it would be faster to do this matmul via einsum + # as some matmuls can be 1.94x slower than an equivalent einsum on MPS + # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0 + attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale + else: + query_states, key_states, value_states = query_proj, key_proj, value_proj + + attention_scores = torch.baddbmm( + torch.empty( + query_states.shape[0], + query_states.shape[1], + key_states.shape[1], + dtype=query_states.dtype, + device=query_states.device, + ), + query_states, + key_states.transpose(-1, -2), + beta=0, + alpha=scale, + ) + attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) # compute attention output - hidden_states = torch.matmul(attention_probs, value_states) - - hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() - new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) - hidden_states = hidden_states.view(new_hidden_states_shape) + if self.num_heads > 1: + # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors? + # or reformulate this into a 3D problem? + # TODO: measure whether on MPS device it would be faster to do this matmul via einsum + # as some matmuls can be 1.94x slower than an equivalent einsum on MPS + # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0 + hidden_states = torch.matmul(attention_probs, value_states) + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + else: + hidden_states = torch.bmm(attention_probs, value_states) # compute next hidden_states hidden_states = self.proj_attn(hidden_states) @@ -126,113 +382,125 @@ def forward(self, hidden_states): return hidden_states -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply - standard transformer action. Finally, reshape to image. - - Parameters: - in_channels (:obj:`int`): The number of channels in the input and output. - n_heads (:obj:`int`): The number of heads to use for multi-head attention. - d_head (:obj:`int`): The number of channels in each head. - depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. - context_dim (:obj:`int`, *optional*): The number of context dimensions to use. - """ - - def __init__( - self, - in_channels: int, - n_heads: int, - d_head: int, - depth: int = 1, - dropout: float = 0.0, - num_groups: int = 32, - context_dim: Optional[int] = None, - ): - super().__init__() - self.n_heads = n_heads - self.d_head = d_head - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) - - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) - for d in range(depth) - ] - ) - - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - - def _set_attention_slice(self, slice_size): - for block in self.transformer_blocks: - block._set_attention_slice(slice_size) - - def forward(self, hidden_states, context=None): - # note: if no context is given, cross-attention defaults to self-attention - batch, channel, height, weight = hidden_states.shape - residual = hidden_states - hidden_states = self.norm(hidden_states) - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) - for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=context) - hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) - hidden_states = self.proj_out(hidden_states) - return hidden_states + residual - - class BasicTransformerBlock(nn.Module): r""" A basic Transformer block. Parameters: - dim (:obj:`int`): The number of channels in the input and output. - n_heads (:obj:`int`): The number of heads to use for multi-head attention. - d_head (:obj:`int`): The number of channels in each head. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. - gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. - checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. """ def __init__( self, dim: int, - n_heads: int, - d_head: int, + num_attention_heads: int, + attention_head_dim: int, dropout=0.0, - context_dim: Optional[int] = None, - gated_ff: bool = True, - checkpoint: bool = True, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, ): super().__init__() + self.only_cross_attention = only_cross_attention self.attn1 = CrossAttention( - query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, ) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) self.attn2 = CrossAttention( - query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, ) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) + + # layer norms + self.use_ada_layer_norm = num_embeds_ada_norm is not None + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint + + # if xformers is installed try to use memory_efficient_attention by default + if is_xformers_available(): + try: + self._set_use_memory_efficient_attention_xformers(True) + except Exception as e: + warnings.warn( + "Could not enable memory efficient attention. Make sure xformers is installed" + f" correctly and a GPU is available: {e}" + ) def _set_attention_slice(self, slice_size): self.attn1._slice_size = slice_size self.attn2._slice_size = slice_size - def forward(self, hidden_states, context=None): - hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states - hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states - hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states + def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, context=None, timestep=None): + # 1. Self-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: + hidden_states = self.attn1(norm_hidden_states, context) + hidden_states + else: + hidden_states = self.attn1(norm_hidden_states) + hidden_states + + # 2. Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states + + # 3. Feed-forward hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + return hidden_states @@ -241,20 +509,28 @@ class CrossAttention(nn.Module): A cross attention layer. Parameters: - query_dim (:obj:`int`): The number of channels in the query. - context_dim (:obj:`int`, *optional*): + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): The number of channels in the context. If not given, defaults to `query_dim`. - heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. - dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. """ def __init__( - self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0 + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, ): super().__init__() inner_dim = dim_head * heads - context_dim = context_dim if context_dim is not None else query_dim + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.scale = dim_head**-0.5 self.heads = heads @@ -262,19 +538,22 @@ def __init__( # is split across the batch axis to save memory # You can set slice_size with `set_attention_slice` self._slice_size = None + self._use_memory_efficient_attention_xformers = False self.dim_head = dim_head - self.context_dim = context_dim + self.context_dim = cross_attention_dim self.query_dim = query_dim if self.context_dim == self.query_dim and self.dim_head <= 128 and (self.dim_head % 8) == 0 and flash_attn_installed: self.flash_attn = FlashAttention(self.scale) self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) - self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape @@ -303,13 +582,21 @@ def forward(self, hidden_states, context=None, mask=None): # TODO(PVP) - mask is currently never used. Remember to re-implement when used # attention, what we cannot get enough of - - if self._slice_size is None or query.shape[0] // self._slice_size == 1: - hidden_states = self._attention(query, key, value) + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) else: - hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) - - return self.to_out(hidden_states) + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states def _attention(self, query, key, value): batch_size = query.shape[0] @@ -360,11 +647,15 @@ def _sliced_attention(self, query, key, value, sequence_length, dim): for i in range(hidden_states.shape[0] // slice_size): start_idx = i * slice_size end_idx = (i + 1) * slice_size - attn_slice = ( - torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale - ) # TODO: use baddbmm for better performance + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query[start_idx:end_idx], + key[start_idx:end_idx].transpose(-1, -2), + beta=0, + alpha=self.scale, + ) attn_slice = attn_slice.softmax(dim=-1) - attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice @@ -372,31 +663,56 @@ def _sliced_attention(self, query, key, value, sequence_length, dim): hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states + def _memory_efficient_attention_xformers(self, query, key, value): + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + class FeedForward(nn.Module): r""" A feed-forward layer. Parameters: - dim (:obj:`int`): The number of channels in the input. - dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. """ def __init__( - self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0 + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim - project_in = GEGLU(dim, inner_dim) - self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + if activation_fn == "geglu": + geglu = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + geglu = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(geglu) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out)) def forward(self, hidden_states): - return self.net(hidden_states) + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states # feedforward @@ -405,14 +721,181 @@ class GEGLU(nn.Module): A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. Parameters: - dim_in (:obj:`int`): The number of channels in the input. - dim_out (:obj:`int`): The number of channels in the output. + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. """ def __init__(self, dim_in: int, dim_out: int): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + def forward(self, hidden_states): hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) - return hidden_states * F.gelu(gate) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + """ + The approximate form of Gaussian Error Linear Unit (GELU) + + For more details, see section 2: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def forward(self, x): + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +class AdaLayerNorm(nn.Module): + """ + Norm layer modified to incorporate timestep embeddings. + """ + + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + self.emb = nn.Embedding(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + + def forward(self, x, timestep): + emb = self.linear(self.silu(self.emb(timestep))) + scale, shift = torch.chunk(emb, 2) + x = self.norm(x) * (1 + scale) + shift + return x + + +class DualTransformer2DModel(nn.Module): + """ + Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of context dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + ): + super().__init__() + self.transformers = nn.ModuleList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + + # Variables that can be set by a pipeline: + + # The ratio of transformer1 to transformer2's output states to be combined during inference + self.mix_ratio = 0.5 + + # The shape of `encoder_hidden_states` is expected to be + # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` + self.condition_lengths = [77, 257] + + # Which transformer to use to encode which condition. + # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` + self.transformer_index_for_condition = [1, 0] + + def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. + """ + input_states = hidden_states + + encoded_states = [] + tokens_start = 0 + for i in range(2): + # for each of the two transformers, pass the corresponding condition tokens + condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] + transformer_index = self.transformer_index_for_condition[i] + encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[ + 0 + ] + encoded_states.append(encoded_state - input_states) + tokens_start += self.condition_lengths[i] + + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) + output_states = output_states + input_states + + if not return_dict: + return (output_states,) + + return Transformer2DModelOutput(sample=output_states) + + def _set_attention_slice(self, slice_size): + for transformer in self.transformers: + transformer._set_attention_slice(slice_size) + + def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for transformer in self.transformers: + transformer._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 1745265b91e1..1b8609474750 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -142,7 +142,7 @@ def __call__(self, hidden_states, context, deterministic=True): return hidden_states -class FlaxSpatialTransformer(nn.Module): +class FlaxTransformer2DModel(nn.Module): r""" A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in: https://arxiv.org/pdf/1506.02025.pdf diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 06b814e2bbcd..0221d891f171 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -62,14 +62,21 @@ def get_timestep_embedding( class TimestepEmbedding(nn.Module): - def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): + def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): super().__init__() - self.linear_1 = nn.Linear(channel, time_embed_dim) + self.linear_1 = nn.Linear(in_channels, time_embed_dim) self.act = None if act_fn == "silu": self.act = nn.SiLU() - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + elif act_fn == "mish": + self.act = nn.Mish() + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) def forward(self, sample): sample = self.linear_1(sample) @@ -101,17 +108,93 @@ def forward(self, timesteps): class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" - def __init__(self, embedding_size: int = 256, scale: float = 1.0): + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): super().__init__() self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos - # to delete later - self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + if set_W_to_weight: + # to delete later + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) - self.weight = self.W + self.weight = self.W def forward(self, x): - x = torch.log(x) + if self.log: + x = torch.log(x) + x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi - out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) return out + + +class ImagePositionalEmbeddings(nn.Module): + """ + Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the + height and width of the latent space. + + For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 + + For VQ-diffusion: + + Output vector embeddings are used as input for the transformer. + + Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. + + Args: + num_embed (`int`): + Number of embeddings for the latent pixels embeddings. + height (`int`): + Height of the latent image i.e. the number of height embeddings. + width (`int`): + Width of the latent image i.e. the number of width embeddings. + embed_dim (`int`): + Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. + """ + + def __init__( + self, + num_embed: int, + height: int, + width: int, + embed_dim: int, + ): + super().__init__() + + self.height = height + self.width = width + self.num_embed = num_embed + self.embed_dim = embed_dim + + self.emb = nn.Embedding(self.num_embed, embed_dim) + self.height_emb = nn.Embedding(self.height, embed_dim) + self.width_emb = nn.Embedding(self.width, embed_dim) + + def forward(self, index): + emb = self.emb(index) + + height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) + + # 1 x H x D -> 1 x H x 1 x D + height_emb = height_emb.unsqueeze(2) + + width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) + + # 1 x W x D -> 1 x 1 x W x D + width_emb = width_emb.unsqueeze(1) + + pos_emb = height_emb + width_emb + + # 1 x H x W x D -> 1 x L xD + pos_emb = pos_emb.view(1, self.height * self.width, -1) + + emb = emb + pos_emb[:, : emb.shape[1], :] + + return emb diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index e2d607499c79..bf7d54b82ec2 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -17,23 +17,41 @@ import jax.numpy as jnp -# This is like models.embeddings.get_timestep_embedding (PyTorch) but -# less general (only handles the case we currently need). -def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1): +def get_sinusoidal_embeddings( + timesteps: jnp.ndarray, + embedding_dim: int, + freq_shift: float = 1, + min_timescale: float = 1, + max_timescale: float = 1.0e4, + flip_sin_to_cos: bool = False, + scale: float = 1.0, +) -> jnp.ndarray: + """Returns the positional encoding (same as Tensor2Tensor). + Args: + timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + embedding_dim: The number of output channels. + min_timescale: The smallest time unit (should probably be 0.0). + max_timescale: The largest time unit. + Returns: + a Tensor of timing signals [N, num_channels] """ - This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + assert timesteps.ndim == 1, "Timesteps should be a 1d-array" + assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even" + num_timescales = float(embedding_dim // 2) + log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift) + inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) + emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) - :param timesteps: a 1-D tensor of N indices, one per batch element. - These may be fractional. - :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the - embeddings. :return: an [N x dim] tensor of positional embeddings. - """ - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - freq_shift) - emb = jnp.exp(jnp.arange(half_dim) * -emb) - emb = timesteps[:, None] * emb[None, :] - emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1) - return emb + # scale embeddings + scaled_time = scale * emb + + if flip_sin_to_cos: + signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1) + else: + signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1) + signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) + return signal class FlaxTimestepEmbedding(nn.Module): @@ -70,4 +88,6 @@ class FlaxTimesteps(nn.Module): @nn.compact def __call__(self, timesteps): - return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift) + return get_sinusoidal_embeddings( + timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift, flip_sin_to_cos=True + ) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index d4cb367ebc0b..52d056ae96fb 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -5,6 +5,75 @@ import torch.nn.functional as F +class Upsample1D(nn.Module): + """ + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + use_conv_transpose: + out_channels: + """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(x) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + if self.use_conv: + x = self.conv(x) + + return x + + +class Downsample1D(nn.Module): + """ + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + out_channels: + padding: + """ + + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.conv(x) + + class Upsample2D(nn.Module): """ An upsampling layer with an optional convolution. @@ -12,7 +81,8 @@ class Upsample2D(nn.Module): Parameters: channels: channels in the inputs and outputs. use_conv: a bool determining if a convolution is applied. - dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. + use_conv_transpose: + out_channels: """ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): @@ -48,6 +118,10 @@ def forward(self, hidden_states, output_size=None): if dtype == torch.bfloat16: hidden_states = hidden_states.to(torch.float32) + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + # if `output_size` is passed we force the interpolation output # size and do not make use of `scale_factor=2` if output_size is None: @@ -76,7 +150,8 @@ class Downsample2D(nn.Module): Parameters: channels: channels in the inputs and outputs. use_conv: a bool determining if a convolution is applied. - dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. + out_channels: + padding: """ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): @@ -376,6 +451,10 @@ def forward(self, input_tensor, temb): hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() input_tensor = self.upsample(input_tensor) hidden_states = self.upsample(hidden_states) elif self.downsample is not None: @@ -407,6 +486,69 @@ def forward(self, hidden_states): return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) +# unet_rl.py +def rearrange_dims(tensor): + if len(tensor.shape) == 2: + return tensor[:, :, None] + if len(tensor.shape) == 3: + return tensor[:, :, None, :] + elif len(tensor.shape) == 4: + return tensor[:, :, 0, :] + else: + raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") + + +class Conv1dBlock(nn.Module): + """ + Conv1d --> GroupNorm --> Mish + """ + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.group_norm = nn.GroupNorm(n_groups, out_channels) + self.mish = nn.Mish() + + def forward(self, x): + x = self.conv1d(x) + x = rearrange_dims(x) + x = self.group_norm(x) + x = rearrange_dims(x) + x = self.mish(x) + return x + + +# unet_rl.py +class ResidualTemporalBlock1D(nn.Module): + def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): + super().__init__() + self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) + self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) + + self.time_emb_act = nn.Mish() + self.time_emb = nn.Linear(embed_dim, out_channels) + + self.residual_conv = ( + nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() + ) + + def forward(self, x, t): + """ + Args: + x : [ batch_size x inp_channels x horizon ] + t : [ batch_size x embed_dim ] + + returns: + out : [ batch_size x out_channels x horizon ] + """ + t = self.time_emb_act(t) + t = self.time_emb(t) + out = self.conv_in(x) + rearrange_dims(t) + out = self.conv_out(out) + return out + self.residual_conv(x) + + def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): r"""Upsample2D a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given @@ -492,10 +634,6 @@ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): kernel_h, kernel_w = kernel.shape out = tensor.view(-1, in_h, 1, in_w, 1, minor) - - # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535 - if tensor.device.type == "mps": - out = out.to("cpu") out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py new file mode 100644 index 000000000000..29d1d707f55a --- /dev/null +++ b/src/diffusers/models/unet_1d.py @@ -0,0 +1,245 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..utils import BaseOutput +from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block + + +@dataclass +class UNet1DOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`): + Hidden states output. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet1DModel(ModelMixin, ConfigMixin): + r""" + UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. + in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. + time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. + freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to : + obj:`False`): Whether to flip sin to cos for fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(32, 32, 64)`): Tuple of block output channels. + mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet. + out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet. + act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks. + norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks. + layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block. + downsample_each_block (`int`, *optional*, defaults to False: + experimental feature for using a UNet without upsampling. + """ + + @register_to_config + def __init__( + self, + sample_size: int = 65536, + sample_rate: Optional[int] = None, + in_channels: int = 2, + out_channels: int = 2, + extra_in_channels: int = 0, + time_embedding_type: str = "fourier", + flip_sin_to_cos: bool = True, + use_timestep_embedding: bool = False, + freq_shift: float = 0.0, + down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), + up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), + mid_block_type: Tuple[str] = "UNetMidBlock1D", + out_block_type: str = None, + block_out_channels: Tuple[int] = (32, 32, 64), + act_fn: str = None, + norm_num_groups: int = 8, + layers_per_block: int = 1, + downsample_each_block: bool = False, + ): + super().__init__() + self.sample_size = sample_size + + # time + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection( + embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift + ) + timestep_input_dim = block_out_channels[0] + + if use_timestep_embedding: + time_embed_dim = block_out_channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=timestep_input_dim, + time_embed_dim=time_embed_dim, + act_fn=act_fn, + out_dim=block_out_channels[0], + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + self.out_block = None + + # down + output_channel = in_channels + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + + if i == 0: + input_channel += extra_in_channels + + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_downsample=not is_final_block or downsample_each_block, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = get_mid_block( + mid_block_type, + in_channels=block_out_channels[-1], + mid_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + embed_dim=block_out_channels[0], + num_layers=layers_per_block, + add_downsample=downsample_each_block, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + if out_block_type is None: + final_upsample_channels = out_channels + else: + final_upsample_channels = block_out_channels[0] + + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = ( + reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels + ) + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block, + in_channels=prev_output_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_upsample=not is_final_block, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) + self.out_block = get_out_block( + out_block_type=out_block_type, + num_groups_out=num_groups_out, + embed_dim=block_out_channels[0], + out_channels=out_channels, + act_fn=act_fn, + fc_dim=block_out_channels[-1] // 4, + ) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + return_dict: bool = True, + ) -> Union[UNet1DOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): `(batch_size, sample_size, num_channels)` noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + timestep_embed = self.time_proj(timesteps) + if self.config.use_timestep_embedding: + timestep_embed = self.time_mlp(timestep_embed) + else: + timestep_embed = timestep_embed[..., None] + timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) + + # 2. down + down_block_res_samples = () + for downsample_block in self.down_blocks: + sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed) + down_block_res_samples += res_samples + + # 3. mid + if self.mid_block: + sample = self.mid_block(sample, timestep_embed) + + # 4. up + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-1:] + down_block_res_samples = down_block_res_samples[:-1] + sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed) + + # 5. post-process + if self.out_block: + sample = self.out_block(sample, timestep_embed) + + if not return_dict: + return (sample,) + + return UNet1DOutput(sample=sample) diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py new file mode 100644 index 000000000000..fc758ebbb044 --- /dev/null +++ b/src/diffusers/models/unet_1d_blocks.py @@ -0,0 +1,668 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch +import torch.nn.functional as F +from torch import nn + +from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims + + +class DownResnetBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + num_layers=1, + conv_shortcut=False, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_downsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.add_downsample = add_downsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) + + def forward(self, hidden_states, temb=None): + output_states = () + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.downsample is not None: + hidden_states = self.downsample(hidden_states) + + return hidden_states, output_states + + +class UpResnetBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + num_layers=1, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.time_embedding_norm = time_embedding_norm + self.add_upsample = add_upsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None + + self.upsample = None + if add_upsample: + self.upsample = Upsample1D(out_channels, use_conv_transpose=True) + + def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None): + if res_hidden_states_tuple is not None: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class ValueFunctionMidBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, embed_dim): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + + self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim) + self.down1 = Downsample1D(out_channels // 2, use_conv=True) + self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) + self.down2 = Downsample1D(out_channels // 4, use_conv=True) + + def forward(self, x, temb=None): + x = self.res1(x, temb) + x = self.down1(x) + x = self.res2(x, temb) + x = self.down2(x) + return x + + +class MidResTemporalBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dim, + num_layers: int = 1, + add_downsample: bool = False, + add_upsample: bool = False, + non_linearity=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_downsample = add_downsample + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None + + self.upsample = None + if add_upsample: + self.upsample = Downsample1D(out_channels, use_conv=True) + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True) + + if self.upsample and self.downsample: + raise ValueError("Block cannot downsample and upsample") + + def forward(self, hidden_states, temb): + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + if self.upsample: + hidden_states = self.upsample(hidden_states) + if self.downsample: + self.downsample = self.downsample(hidden_states) + + return hidden_states + + +class OutConv1DBlock(nn.Module): + def __init__(self, num_groups_out, out_channels, embed_dim, act_fn): + super().__init__() + self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) + self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) + if act_fn == "silu": + self.final_conv1d_act = nn.SiLU() + if act_fn == "mish": + self.final_conv1d_act = nn.Mish() + self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) + + def forward(self, hidden_states, temb=None): + hidden_states = self.final_conv1d_1(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_gn(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_act(hidden_states) + hidden_states = self.final_conv1d_2(hidden_states) + return hidden_states + + +class OutValueFunctionBlock(nn.Module): + def __init__(self, fc_dim, embed_dim): + super().__init__() + self.final_block = nn.ModuleList( + [ + nn.Linear(fc_dim + embed_dim, fc_dim // 2), + nn.Mish(), + nn.Linear(fc_dim // 2, 1), + ] + ) + + def forward(self, hidden_states, temb): + hidden_states = hidden_states.view(hidden_states.shape[0], -1) + hidden_states = torch.cat((hidden_states, temb), dim=-1) + for layer in self.final_block: + hidden_states = layer(hidden_states) + + return hidden_states + + +_kernels = { + "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], + "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875], + "lanczos3": [ + 0.003689131001010537, + 0.015056144446134567, + -0.03399861603975296, + -0.066637322306633, + 0.13550527393817902, + 0.44638532400131226, + 0.44638532400131226, + 0.13550527393817902, + -0.066637322306633, + -0.03399861603975296, + 0.015056144446134567, + 0.003689131001010537, + ], +} + + +class Downsample1d(nn.Module): + def __init__(self, kernel="linear", pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states): + hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) + weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + weight[indices, indices] = self.kernel.to(weight) + return F.conv1d(hidden_states, weight, stride=2) + + +class Upsample1d(nn.Module): + def __init__(self, kernel="linear", pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states, temb=None): + hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + weight[indices, indices] = self.kernel.to(weight) + return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) + + +class SelfAttention1d(nn.Module): + def __init__(self, in_channels, n_head=1, dropout_rate=0.0): + super().__init__() + self.channels = in_channels + self.group_norm = nn.GroupNorm(1, num_channels=in_channels) + self.num_heads = n_head + + self.query = nn.Linear(self.channels, self.channels) + self.key = nn.Linear(self.channels, self.channels) + self.value = nn.Linear(self.channels, self.channels) + + self.proj_attn = nn.Linear(self.channels, self.channels, 1) + + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states): + residual = hidden_states + batch, channel_dim, seq = hidden_states.shape + + hidden_states = self.group_norm(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1])) + + attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) + attention_probs = torch.softmax(attention_scores, dim=-1) + + # compute attention output + hidden_states = torch.matmul(attention_probs, value_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.dropout(hidden_states) + + output = hidden_states + residual + + return output + + +class ResConvBlock(nn.Module): + def __init__(self, in_channels, mid_channels, out_channels, is_last=False): + super().__init__() + self.is_last = is_last + self.has_conv_skip = in_channels != out_channels + + if self.has_conv_skip: + self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False) + + self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2) + self.group_norm_1 = nn.GroupNorm(1, mid_channels) + self.gelu_1 = nn.GELU() + self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2) + + if not self.is_last: + self.group_norm_2 = nn.GroupNorm(1, out_channels) + self.gelu_2 = nn.GELU() + + def forward(self, hidden_states): + residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states + + hidden_states = self.conv_1(hidden_states) + hidden_states = self.group_norm_1(hidden_states) + hidden_states = self.gelu_1(hidden_states) + hidden_states = self.conv_2(hidden_states) + + if not self.is_last: + hidden_states = self.group_norm_2(hidden_states) + hidden_states = self.gelu_2(hidden_states) + + output = hidden_states + residual + return output + + +class UNetMidBlock1D(nn.Module): + def __init__(self, mid_channels, in_channels, out_channels=None): + super().__init__() + + out_channels = in_channels if out_channels is None else out_channels + + # there is always at least one resnet + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + self.up = Upsample1d(kernel="cubic") + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.down(hidden_states) + for attn, resnet in zip(self.attentions, self.resnets): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class AttnDownBlock1D(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.down(hidden_states) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + return hidden_states, (hidden_states,) + + +class DownBlock1D(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.down(hidden_states) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states, (hidden_states,) + + +class DownBlock1DNoSkip(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = torch.cat([hidden_states, temb], dim=1) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states, (hidden_states,) + + +class AttnUpBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class UpBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class UpBlock1DNoSkip(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states + + +def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample): + if down_block_type == "DownResnetBlock1D": + return DownResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "DownBlock1D": + return DownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "AttnDownBlock1D": + return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "DownBlock1DNoSkip": + return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample): + if up_block_type == "UpResnetBlock1D": + return UpResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + ) + elif up_block_type == "UpBlock1D": + return UpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "AttnUpBlock1D": + return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "UpBlock1DNoSkip": + return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) + raise ValueError(f"{up_block_type} does not exist.") + + +def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample): + if mid_block_type == "MidResTemporalBlock1D": + return MidResTemporalBlock1D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + embed_dim=embed_dim, + add_downsample=add_downsample, + ) + elif mid_block_type == "ValueFunctionMidBlock1D": + return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim) + elif mid_block_type == "UNetMidBlock1D": + return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) + raise ValueError(f"{mid_block_type} does not exist.") + + +def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim): + if out_block_type == "OutConv1DBlock": + return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) + elif out_block_type == "ValueFunction": + return OutValueFunctionBlock(fc_dim, embed_dim) + return None diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index b580b4ed4056..5b337f482cc5 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -21,7 +21,7 @@ from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block +from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @dataclass @@ -43,15 +43,15 @@ class UNet2DModel(ModelMixin, ConfigMixin): implements for all the model (such as downloading or saving, etc.) Parameters: - sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): - Input sample size. + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. flip_sin_to_cos (`bool`, *optional*, defaults to : - obj:`False`): Whether to flip sin to cos for fourier time embedding. + obj:`True`): Whether to flip sin to cos for fourier time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to : obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block types. @@ -71,7 +71,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - sample_size: Optional[int] = None, + sample_size: Optional[Union[int, Tuple[int, int]]] = None, in_channels: int = 3, out_channels: int = 3, center_input_sample: bool = False, @@ -175,7 +175,7 @@ def __init__( num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def forward( self, @@ -209,6 +209,11 @@ def forward( timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) # 2. pre-process @@ -242,9 +247,7 @@ def forward( sample = upsample_block(sample, res_samples, emb) # 6. post-process - # make sure hidden states is in float32 - # when running in half-precision - sample = self.conv_norm_out(sample.float()).type(sample.dtype) + sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_2d_blocks.py similarity index 84% rename from src/diffusers/models/unet_blocks.py rename to src/diffusers/models/unet_2d_blocks.py index f433eb5db06e..cce7e7fd5a90 100644 --- a/src/diffusers/models/unet_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -15,7 +15,7 @@ import torch from torch import nn -from .attention import AttentionBlock, SpatialTransformer +from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D @@ -32,6 +32,9 @@ def get_down_block( resnet_groups=None, cross_attention_dim=None, downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": @@ -74,6 +77,9 @@ def get_down_block( downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, ) elif down_block_type == "SkipDownBlock2D": return SkipDownBlock2D( @@ -109,6 +115,19 @@ def get_down_block( resnet_groups=resnet_groups, downsample_padding=downsample_padding, ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + ) + raise ValueError(f"{down_block_type} does not exist.") def get_up_block( @@ -124,6 +143,9 @@ def get_up_block( attn_num_head_channels, resnet_groups=None, cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": @@ -153,6 +175,9 @@ def get_up_block( resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, ) elif up_block_type == "AttnUpBlock2D": return AttnUpBlock2D( @@ -200,6 +225,17 @@ def get_up_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attn_num_head_channels=attn_num_head_channels, + ) raise ValueError(f"{up_block_type} does not exist.") @@ -218,7 +254,6 @@ def __init__( attn_num_head_channels=1, attention_type="default", output_scale_factor=1.0, - **kwargs, ): super().__init__() @@ -249,7 +284,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) resnets.append( @@ -298,7 +333,8 @@ def __init__( attention_type="default", output_scale_factor=1.0, cross_attention_dim=1280, - **kwargs, + dual_cross_attention=False, + use_linear_projection=False, ): super().__init__() @@ -324,16 +360,29 @@ def __init__( attentions = [] for _ in range(num_layers): - attentions.append( - SpatialTransformer( - in_channels, - attn_num_head_channels, - in_channels // attn_num_head_channels, - depth=1, - context_dim=cross_attention_dim, - num_groups=resnet_groups, + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) ) - ) resnets.append( ResnetBlock2D( in_channels=in_channels, @@ -353,24 +402,30 @@ def __init__( self.resnets = nn.ModuleList(resnets) def set_attention_slice(self, slice_size): - if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + head_dims = self.attn_num_head_channels + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): raise ValueError( - f"Make sure slice_size {slice_size} is a divisor of " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" ) - if slice_size is not None and slice_size > self.attn_num_head_channels: + if slice_size is not None and slice_size > min(head_dims): raise ValueError( - f"Chunk_size {slice_size} has to be smaller or equal to " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" ) for attn in self.attentions: attn._set_attention_slice(slice_size) + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states) + hidden_states = attn(hidden_states, encoder_hidden_states).sample hidden_states = resnet(hidden_states, temb) return hidden_states @@ -423,7 +478,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) @@ -434,7 +489,7 @@ def __init__( self.downsamplers = nn.ModuleList( [ Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) @@ -477,6 +532,9 @@ def __init__( output_scale_factor=1.0, downsample_padding=1, add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, ): super().__init__() resnets = [] @@ -501,16 +559,30 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - attentions.append( - SpatialTransformer( - out_channels, - attn_num_head_channels, - out_channels // attn_num_head_channels, - depth=1, - context_dim=cross_attention_dim, - num_groups=resnet_groups, + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) ) - ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -518,7 +590,7 @@ def __init__( self.downsamplers = nn.ModuleList( [ Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) @@ -528,39 +600,48 @@ def __init__( self.gradient_checkpointing = False def set_attention_slice(self, slice_size): - if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + head_dims = self.attn_num_head_channels + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): raise ValueError( - f"Make sure slice_size {slice_size} is a divisor of " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" ) - if slice_size is not None and slice_size > self.attn_num_head_channels: + if slice_size is not None and slice_size > min(head_dims): raise ValueError( - f"Chunk_size {slice_size} has to be smaller or equal to " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" ) for attn in self.attentions: attn._set_attention_slice(slice_size) + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: - def create_custom_forward(module): + def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): - return module(*inputs) + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn), hidden_states, encoder_hidden_states - ) + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=encoder_hidden_states) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample output_states += (hidden_states,) @@ -616,7 +697,7 @@ def __init__( self.downsamplers = nn.ModuleList( [ Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) @@ -694,7 +775,7 @@ def __init__( self.downsamplers = nn.ModuleList( [ Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) @@ -755,7 +836,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) @@ -766,7 +847,7 @@ def __init__( self.downsamplers = nn.ModuleList( [ Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) @@ -851,7 +932,7 @@ def __init__( down=True, kernel="fir", ) - self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) else: self.resnet_down = None @@ -931,7 +1012,7 @@ def __init__( down=True, kernel="fir", ) - self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) else: self.resnet_down = None @@ -1006,7 +1087,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) @@ -1053,8 +1134,10 @@ def __init__( cross_attention_dim=1280, attention_type="default", output_scale_factor=1.0, - downsample_padding=1, add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, ): super().__init__() resnets = [] @@ -1081,16 +1164,30 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - attentions.append( - SpatialTransformer( - out_channels, - attn_num_head_channels, - out_channels // attn_num_head_channels, - depth=1, - context_dim=cross_attention_dim, - num_groups=resnet_groups, + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) ) - ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -1102,15 +1199,17 @@ def __init__( self.gradient_checkpointing = False def set_attention_slice(self, slice_size): - if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + head_dims = self.attn_num_head_channels + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): raise ValueError( - f"Make sure slice_size {slice_size} is a divisor of " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" ) - if slice_size is not None and slice_size > self.attn_num_head_channels: + if slice_size is not None and slice_size > min(head_dims): raise ValueError( - f"Chunk_size {slice_size} has to be smaller or equal to " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" ) for attn in self.attentions: @@ -1118,6 +1217,10 @@ def set_attention_slice(self, slice_size): self.gradient_checkpointing = False + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + def forward( self, hidden_states, @@ -1134,19 +1237,22 @@ def forward( if self.training and self.gradient_checkpointing: - def create_custom_forward(module): + def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): - return module(*inputs) + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn), hidden_states, encoder_hidden_states - ) + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=encoder_hidden_states) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1326,7 +1432,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py similarity index 98% rename from src/diffusers/models/unet_blocks_flax.py rename to src/diffusers/models/unet_2d_blocks_flax.py index baa71beabe35..5798385b9d28 100644 --- a/src/diffusers/models/unet_blocks_flax.py +++ b/src/diffusers/models/unet_2d_blocks_flax.py @@ -15,7 +15,7 @@ import flax.linen as nn import jax.numpy as jnp -from .attention_flax import FlaxSpatialTransformer +from .attention_flax import FlaxTransformer2DModel from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D @@ -63,7 +63,7 @@ def setup(self): ) resnets.append(res_block) - attn_block = FlaxSpatialTransformer( + attn_block = FlaxTransformer2DModel( in_channels=self.out_channels, n_heads=self.attn_num_head_channels, d_head=self.out_channels // self.attn_num_head_channels, @@ -196,7 +196,7 @@ def setup(self): ) resnets.append(res_block) - attn_block = FlaxSpatialTransformer( + attn_block = FlaxTransformer2DModel( in_channels=self.out_channels, n_heads=self.attn_num_head_channels, d_head=self.out_channels // self.attn_num_head_channels, @@ -326,7 +326,7 @@ def setup(self): attentions = [] for _ in range(self.num_layers): - attn_block = FlaxSpatialTransformer( + attn_block = FlaxTransformer2DModel( in_channels=self.in_channels, n_heads=self.attn_num_head_channels, d_head=self.in_channels // self.attn_num_head_channels, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index a65e238969dd..1b43f960d997 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -22,7 +22,7 @@ from ..modeling_utils import ModelMixin from ..utils import BaseOutput, logging from .embeddings import TimestepEmbedding, Timesteps -from .unet_blocks import ( +from .unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, @@ -56,7 +56,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): implements for all the models (such as downloading or saving, etc.) Parameters: - sample_size (`int`, *optional*): The size of the input sample. + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. @@ -97,6 +98,7 @@ def __init__( "DownBlock2D", ), up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, @@ -105,7 +107,10 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, - attention_head_dim: int = 8, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + num_class_embeds: Optional[int] = None, ): super().__init__() @@ -121,10 +126,20 @@ def __init__( self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + # class embedding + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + self.down_blocks = nn.ModuleList([]) self.mid_block = None self.up_blocks = nn.ModuleList([]) + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): @@ -143,8 +158,11 @@ def __init__( resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], ) self.down_blocks.append(down_block) @@ -157,8 +175,10 @@ def __init__( output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift="default", cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, ) # count how many layers upsample the images @@ -166,6 +186,8 @@ def __init__( # up reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 @@ -193,7 +215,10 @@ def __init__( resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -201,18 +226,20 @@ def __init__( # out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def set_attention_slice(self, slice_size): - if slice_size is not None and self.config.attention_head_dim % slice_size != 0: + head_dims = self.config.attention_head_dim + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): raise ValueError( - f"Make sure slice_size {slice_size} is a divisor of " - f"the number of heads used in cross_attention {self.config.attention_head_dim}" + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" ) - if slice_size is not None and slice_size > self.config.attention_head_dim: + if slice_size is not None and slice_size > min(head_dims): raise ValueError( - f"Chunk_size {slice_size} has to be smaller or equal to " - f"the number of heads used in cross_attention {self.config.attention_head_dim}" + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" ) for block in self.down_blocks: @@ -225,6 +252,17 @@ def set_attention_slice(self, slice_size): if hasattr(block, "attentions") and block.attentions is not None: block.set_attention_slice(slice_size) + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for block in self.down_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + for block in self.up_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): module.gradient_checkpointing = value @@ -234,6 +272,7 @@ def forward( sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -286,6 +325,12 @@ def forward( t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) + if self.config.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + # 2. pre-process sample = self.conv_in(sample) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 5411855f79d5..7ca9c191b448 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -23,7 +23,7 @@ from ..modeling_flax_utils import FlaxModelMixin from ..utils import BaseOutput from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps -from .unet_blocks_flax import ( +from .unet_2d_blocks_flax import ( FlaxCrossAttnDownBlock2D, FlaxCrossAttnUpBlock2D, FlaxDownBlock2D, @@ -230,9 +230,9 @@ def __call__( ) -> Union[FlaxUNet2DConditionOutput, Tuple]: r""" Args: - sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor + sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor timestep (`jnp.ndarray` or `float` or `int`): timesteps - encoder_hidden_states (`jnp.ndarray`): (channel, height, width) encoder hidden states + encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a plain tuple. diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index f1c0fc1416ee..30de343d08ee 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput -from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block +from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @dataclass @@ -233,14 +233,16 @@ class VectorQuantizer(nn.Module): # NOTE: due to a bug the beta term was applied to the wrong term. for # backwards compatibility we use the buggy version by default, but you can # specify legacy=False to fix it. - def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): + def __init__( + self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True + ): super().__init__() self.n_e = n_e - self.e_dim = e_dim + self.vq_embed_dim = vq_embed_dim self.beta = beta self.legacy = legacy - self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) self.remap = remap @@ -287,7 +289,7 @@ def unmap_to_all(self, inds): def forward(self, z): # reshape z -> (batch, height, width, channel) and flatten z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.e_dim) + z_flattened = z.view(-1, self.vq_embed_dim) # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z d = ( @@ -409,6 +411,7 @@ class VQModel(ModelMixin, ConfigMixin): latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. sample_size (`int`, *optional*, defaults to `32`): TODO num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. + vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. """ @register_to_config @@ -425,6 +428,7 @@ def __init__( sample_size: int = 32, num_vq_embeddings: int = 256, norm_num_groups: int = 32, + vq_embed_dim: Optional[int] = None, ): super().__init__() @@ -440,11 +444,11 @@ def __init__( double_z=False, ) - self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) - self.quantize = VectorQuantizer( - num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False - ) - self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels + + self.quant_conv = torch.nn.Conv2d(latent_channels, vq_embed_dim, 1) + self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) + self.post_quant_conv = torch.nn.Conv2d(vq_embed_dim, latent_channels, 1) # pass init params to Decoder self.decoder = Decoder( diff --git a/src/diffusers/onnx_utils.py b/src/diffusers/onnx_utils.py index 142174f6e101..b2c533ed741f 100644 --- a/src/diffusers/onnx_utils.py +++ b/src/diffusers/onnx_utils.py @@ -24,7 +24,7 @@ from huggingface_hub import hf_hub_download -from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging +from .utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging if is_onnx_available(): @@ -33,13 +33,28 @@ logger = logging.get_logger(__name__) +ORT_TO_NP_TYPE = { + "tensor(bool)": np.bool_, + "tensor(int8)": np.int8, + "tensor(uint8)": np.uint8, + "tensor(int16)": np.int16, + "tensor(uint16)": np.uint16, + "tensor(int32)": np.int32, + "tensor(uint32)": np.uint32, + "tensor(int64)": np.int64, + "tensor(uint64)": np.uint64, + "tensor(float16)": np.float16, + "tensor(float)": np.float32, + "tensor(double)": np.float64, +} + class OnnxRuntimeModel: def __init__(self, model=None, **kwargs): logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.") self.model = model self.model_save_dir = kwargs.get("model_save_dir", None) - self.latest_model_name = kwargs.get("latest_model_name", "model.onnx") + self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME) def __call__(self, **kwargs): inputs = {k: np.array(v) for k, v in kwargs.items()} @@ -84,6 +99,15 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional except shutil.SameFileError: pass + # copy external weights (for models >2GB) + src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME) + if src_path.exists(): + dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME) + try: + shutil.copyfile(src_path, dst_path) + except shutil.SameFileError: + pass + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index d55338b50343..bf2e259ea1a7 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -29,6 +29,7 @@ from tqdm.auto import tqdm from .configuration_utils import ConfigMixin +from .hub_utils import http_user_agent from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging @@ -46,7 +47,7 @@ LOADABLE_CLASSES = { "diffusers": { "FlaxModelMixin": ["save_pretrained", "from_pretrained"], - "FlaxSchedulerMixin": ["save_config", "from_config"], + "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"], "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"], }, "transformers": { @@ -54,6 +55,8 @@ "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"], "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + "ProcessorMixin": ["save_pretrained", "from_pretrained"], + "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], }, } @@ -160,6 +163,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) + if sub_model is None: + # edge case for saving a pipeline with safety_checker=None + continue + model_cls = sub_model.__class__ save_method_name = None @@ -167,8 +174,8 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union for library_name, library_classes in LOADABLE_CLASSES.items(): library = importlib.import_module(library_name) for base_class, save_load_methods in library_classes.items(): - class_candidate = getattr(library, base_class) - if issubclass(model_cls, class_candidate): + class_candidate = getattr(library, base_class, None) + if class_candidate is not None and issubclass(model_cls, class_candidate): # if we found a suitable base class in LOADABLE_CLASSES then grab its save method save_method_name = save_load_methods[0] break @@ -244,7 +251,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"CompVis/stable-diffusion-v1-4"` + models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"` @@ -261,18 +268,27 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P >>> from diffusers import FlaxDiffusionPipeline >>> # Download pipeline from huggingface.co and cache. - >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") - - >>> # Download pipeline that requires an authorization token - >>> # For more information on access tokens, please refer to this section - >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) - >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") - - >>> # Download pipeline, but overwrite scheduler - >>> from diffusers import LMSDiscreteScheduler - - >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") - >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler) + >>> # Requires to be logged in to Hugging Face hub, + >>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline, params = FlaxDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", + ... revision="bf16", + ... dtype=jnp.bfloat16, + ... ) + + >>> # Download pipeline, but use a different scheduler + >>> from diffusers import FlaxDPMSolverMultistepScheduler + + >>> model_id = "runwayml/stable-diffusion-v1-5" + >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained( + ... model_id, + ... subfolder="scheduler", + ... ) + + >>> dpm_pipe, dpm_params = FlaxStableDiffusionPipeline.from_pretrained( + ... model_id, revision="bf16", dtype=jnp.bfloat16, scheduler=dpmpp + ... ) + >>> dpm_params["scheduler"] = dpmpp_state ``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) @@ -287,7 +303,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): - config_dict = cls.get_config_dict( + config_dict = cls.load_config( pretrained_model_name_or_path, cache_dir=cache_dir, resume_download=resume_download, @@ -301,6 +317,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P allow_patterns = [os.path.join(k, "*") for k in folder_names] allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] + # make sure we don't download PyTorch weights + ignore_patterns = "*.bin" + + if cls != FlaxDiffusionPipeline: + requested_pipeline_class = cls.__name__ + else: + requested_pipeline_class = config_dict.get("_class_name", cls.__name__) + requested_pipeline_class = ( + requested_pipeline_class + if requested_pipeline_class.startswith("Flax") + else "Flax" + requested_pipeline_class + ) + + user_agent = {"pipeline_class": requested_pipeline_class} + user_agent = http_user_agent(user_agent) + # download all allow_patterns cached_folder = snapshot_download( pretrained_model_name_or_path, @@ -311,11 +343,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_auth_token=use_auth_token, revision=revision, allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, ) else: cached_folder = pretrained_model_name_or_path - config_dict = cls.get_config_dict(cached_folder) + config_dict = cls.load_config(cached_folder) # 2. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it @@ -328,7 +362,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if config_dict["_class_name"].startswith("Flax") else "Flax" + config_dict["_class_name"] ) - pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + pipeline_class = getattr(diffusers_module, class_name) # some modules can be passed directly to the init # in this case they are already instantiated in `kwargs` @@ -336,7 +370,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} - init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_kwargs = {} @@ -348,6 +382,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # 3. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): + if class_name is None: + # edge case for when the pipeline was saved with safety_checker=None + init_kwargs[name] = None + continue + is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None sub_model_should_be_defined = True @@ -359,11 +398,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P library = importlib.import_module(library_name) class_obj = getattr(library, class_name) importable_classes = LOADABLE_CLASSES[library_name] - class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} expected_class_obj = None for class_name, class_candidate in class_candidates.items(): - if issubclass(class_obj, class_candidate): + if class_candidate is not None and issubclass(class_obj, class_candidate): expected_class_obj = class_candidate if not issubclass(passed_class_obj[name].__class__, expected_class_obj): @@ -372,13 +411,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P f" {expected_class_obj}" ) elif passed_class_obj[name] is None: - logger.warn( + logger.warning( f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" f" that this might lead to problems when using {pipeline_class} and is not recommended." ) sub_model_should_be_defined = False else: - logger.warn( + logger.warning( f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" " has the correct type" ) @@ -397,12 +436,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P class_obj = import_flax_or_no_model(library, class_name) importable_classes = LOADABLE_CLASSES[library_name] - class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} if loaded_sub_model is None and sub_model_should_be_defined: load_method_name = None for class_name, class_candidate in class_candidates.items(): - if issubclass(class_obj, class_candidate): + if class_candidate is not None and issubclass(class_obj, class_candidate): load_method_name = importable_classes[class_name][1] load_method = getattr(class_obj, load_method_name) @@ -444,7 +483,11 @@ def numpy_to_pil(images): if images.ndim == 3: images = images[None, ...] images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] return pil_images diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index d3e6113dac8e..35ebd536c511 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -18,7 +18,8 @@ import inspect import os from dataclasses import dataclass -from typing import List, Optional, Union +from pathlib import Path +from typing import Any, Dict, List, Optional, Union import numpy as np import torch @@ -30,9 +31,10 @@ from PIL import Image from tqdm.auto import tqdm -from . import __version__ from .configuration_utils import ConfigMixin from .dynamic_modules_utils import get_class_from_dynamic_module +from .hub_utils import http_user_agent +from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from .utils import ( CONFIG_NAME, @@ -40,6 +42,9 @@ ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, + deprecate, + is_accelerate_available, + is_torch_version, is_transformers_available, logging, ) @@ -53,6 +58,7 @@ INDEX_FILE = "diffusion_pytorch_model.bin" CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" DUMMY_MODULES_FOLDER = "diffusers.utils" +TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils" logger = logging.get_logger(__name__) @@ -61,7 +67,7 @@ LOADABLE_CLASSES = { "diffusers": { "ModelMixin": ["save_pretrained", "from_pretrained"], - "SchedulerMixin": ["save_config", "from_config"], + "SchedulerMixin": ["save_pretrained", "from_pretrained"], "DiffusionPipeline": ["save_pretrained", "from_pretrained"], "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], }, @@ -70,6 +76,11 @@ "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], "PreTrainedModel": ["save_pretrained", "from_pretrained"], "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + "ProcessorMixin": ["save_pretrained", "from_pretrained"], + "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], + }, + "onnxruntime.training": { + "ORTModule": ["save_pretrained", "from_pretrained"], }, } @@ -92,6 +103,20 @@ class ImagePipelineOutput(BaseOutput): images: Union[List[PIL.Image.Image], np.ndarray] +@dataclass +class AudioPipelineOutput(BaseOutput): + """ + Output class for audio pipelines. + + Args: + audios (`np.ndarray`) + List of denoised samples of shape `(batch_size, num_channels, sample_rate)`. Numpy array present the + denoised audio samples of the diffusion pipeline. + """ + + audios: np.ndarray + + class DiffusionPipeline(ConfigMixin): r""" Base class for all models. @@ -104,10 +129,13 @@ class DiffusionPipeline(ConfigMixin): Class attributes: - - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all + - **config_name** (`str`) -- name of the config file that will store the class and module names of all components of the diffusion pipeline. + - **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be + passed for the pipeline to function (should be overridden by subclasses). """ config_name = "model_index.json" + _optional_components = [] def register_modules(self, **kwargs): # import it here to avoid circular import @@ -121,7 +149,7 @@ def register_modules(self, **kwargs): library = module.__module__.split(".")[0] # check if the module is a pipeline module - pipeline_dir = module.__module__.split(".")[-2] + pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None path = module.__module__.split(".") is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) @@ -159,6 +187,17 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]): model_index_dict.pop("_diffusers_version") model_index_dict.pop("_module", None) + expected_modules, optional_kwargs = self._get_signature_keys(self) + + def is_saveable_module(name, value): + if name not in expected_modules: + return False + if name in self._optional_components and value[0] is None: + return False + return True + + model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)} + for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) model_cls = sub_model.__class__ @@ -168,8 +207,8 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]): for library_name, library_classes in LOADABLE_CLASSES.items(): library = importlib.import_module(library_name) for base_class, save_load_methods in library_classes.items(): - class_candidate = getattr(library, base_class) - if issubclass(model_cls, class_candidate): + class_candidate = getattr(library, base_class, None) + if class_candidate is not None and issubclass(model_cls, class_candidate): # if we found a suitable base class in LOADABLE_CLASSES then grab its save method save_method_name = save_load_methods[0] break @@ -183,17 +222,17 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None): if torch_device is None: return self - module_names, _ = self.extract_init_dict(dict(self.config)) + module_names, _, _ = self.extract_init_dict(dict(self.config)) for name in module_names.keys(): module = getattr(self, name) if isinstance(module, torch.nn.Module): - if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]: + if module.dtype == torch.float16 and str(torch_device) in ["cpu"]: logger.warning( - "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It" - " is not recommended to move them to `cpu` or `mps` as running them will fail. Please make" - " sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for" - " `float16` operations on those devices in PyTorch. Please remove the" - " `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference." + "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It" + " is not recommended to move them to `cpu` as running them will fail. Please make" + " sure to use an accelerator to run the pipeline in inference, due to the lack of" + " support for`float16` operations on this device in PyTorch. Please, remove the" + " `torch_dtype=torch.float16` argument, or use another device for inference." ) module.to(torch_device) return self @@ -204,7 +243,7 @@ def device(self) -> torch.device: Returns: `torch.device`: The torch device on which the pipeline is located. """ - module_names, _ = self.extract_init_dict(dict(self.config)) + module_names, _, _ = self.extract_init_dict(dict(self.config)) for name in module_names.keys(): module = getattr(self, name) if isinstance(module, torch.nn.Module): @@ -279,8 +318,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P For more information on how to load and create custom pipelines, please have a look at [Loading and - Creating Custom - Pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/custom_pipelines) + Adding Custom + Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) torch_dtype (`str` or `torch.dtype`, *optional*): force_download (`bool`, *optional*, defaults to `False`): @@ -307,6 +346,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please refer to the mirror site for more information. specify the folder name here. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the @@ -316,7 +368,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"CompVis/stable-diffusion-v1-4"` + models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"` @@ -338,17 +390,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P >>> # Download pipeline that requires an authorization token >>> # For more information on access tokens, please refer to this section >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) - >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") - >>> # Download pipeline, but overwrite scheduler + >>> # Use a different scheduler >>> from diffusers import LMSDiscreteScheduler - >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") - >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler) + >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.scheduler = scheduler ``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) @@ -358,14 +411,43 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P provider = kwargs.pop("provider", None) sess_options = kwargs.pop("sess_options", None) device_map = kwargs.pop("device_map", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): - config_dict = cls.get_config_dict( + config_dict = cls.load_config( pretrained_model_name_or_path, cache_dir=cache_dir, resume_download=resume_download, + force_download=force_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, @@ -376,13 +458,20 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P allow_patterns = [os.path.join(k, "*") for k in folder_names] allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name] + # make sure we don't download flax weights + ignore_patterns = "*.msgpack" + if custom_pipeline is not None: allow_patterns += [CUSTOM_PIPELINE_FILE_NAME] - requested_pipeline_class = config_dict.get("_class_name", cls.__name__) - user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class} + if cls != DiffusionPipeline: + requested_pipeline_class = cls.__name__ + else: + requested_pipeline_class = config_dict.get("_class_name", cls.__name__) + user_agent = {"pipeline_class": requested_pipeline_class} if custom_pipeline is not None: user_agent["custom_pipeline"] = custom_pipeline + user_agent = http_user_agent(user_agent) # download all allow_patterns cached_folder = snapshot_download( @@ -394,18 +483,27 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_auth_token=use_auth_token, revision=revision, allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, user_agent=user_agent, ) else: cached_folder = pretrained_model_name_or_path - config_dict = cls.get_config_dict(cached_folder) + config_dict = cls.load_config(cached_folder) # 2. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it if custom_pipeline is not None: + if custom_pipeline.endswith(".py"): + path = Path(custom_pipeline) + # decompose into folder & file + file_name = path.name + custom_pipeline = path.parent.absolute() + else: + file_name = CUSTOM_PIPELINE_FILE_NAME + pipeline_class = get_class_from_dynamic_module( - custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline + custom_pipeline, module_file=file_name, cache_dir=custom_pipeline ) elif cls != DiffusionPipeline: pipeline_class = cls @@ -413,15 +511,52 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + # To be removed in 1.0.0 + if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( + version.parse(config_dict["_diffusers_version"]).base_version + ) <= version.parse("0.5.1"): + from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy + + pipeline_class = StableDiffusionInpaintPipelineLegacy + + deprecation_message = ( + "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the" + f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For" + " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting" + " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your" + f" checkpoint {pretrained_model_name_or_path} to the format of" + " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain" + " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0." + ) + deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) + # some modules can be passed directly to the init # in this case they are already instantiated in `kwargs` # extract them here - expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"]) + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + + init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) - init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + # define init kwargs + init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict} + init_kwargs = {**init_kwargs, **passed_pipe_kwargs} - init_kwargs = {} + # remove `null` components + def load_module(name, value): + if value[0] is None: + return False + if name in passed_class_obj and passed_class_obj[name] is None: + return False + return True + + init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + + if len(unused_kwargs) > 0: + logger.warning( + f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." + ) # import it here to avoid circular import from diffusers import pipelines @@ -434,7 +569,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None - sub_model_should_be_defined = True # if the model is in a pipeline module, then we load it from the pipeline if name in passed_class_obj: @@ -443,11 +577,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P library = importlib.import_module(library_name) class_obj = getattr(library, class_name) importable_classes = LOADABLE_CLASSES[library_name] - class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} expected_class_obj = None for class_name, class_candidate in class_candidates.items(): - if issubclass(class_obj, class_candidate): + if class_candidate is not None and issubclass(class_obj, class_candidate): expected_class_obj = class_candidate if not issubclass(passed_class_obj[name].__class__, expected_class_obj): @@ -455,14 +589,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" f" {expected_class_obj}" ) - elif passed_class_obj[name] is None: - logger.warn( - f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" - f" that this might lead to problems when using {pipeline_class} and is not recommended." - ) - sub_model_should_be_defined = False else: - logger.warn( + logger.warning( f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" " has the correct type" ) @@ -477,19 +605,23 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P else: # else we just import it from the library. library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) importable_classes = LOADABLE_CLASSES[library_name] - class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} - if loaded_sub_model is None and sub_model_should_be_defined: + if loaded_sub_model is None: load_method_name = None for class_name, class_candidate in class_candidates.items(): - if issubclass(class_obj, class_candidate): + if class_candidate is not None and issubclass(class_obj, class_candidate): load_method_name = importable_classes[class_name][1] if load_method_name is None: none_module = class_obj.__module__ - if none_module.startswith(DUMMY_MODULES_FOLDER) and "dummy" in none_module: + is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( + TRANSFORMERS_DUMMY_MODULES_FOLDER + ) + if is_dummy_path and "dummy" in none_module: # call class_obj for nice error message of missing requirements class_obj() @@ -514,8 +646,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0") ) + # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. + # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. + # This makes sure that the weights won't be initialized which significantly speeds up loading. if is_diffusers_model or is_transformers_model: loading_kwargs["device_map"] = device_map + loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage # check if the module is in a subdirectory if os.path.isdir(os.path.join(cached_folder, name)): @@ -528,11 +664,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # 4. Potentially add passed objects if expected missing_modules = set(expected_modules) - set(init_kwargs.keys()) - if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()): + passed_modules = list(passed_class_obj.keys()) + optional_modules = pipeline_class._optional_components + if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): for module in missing_modules: - init_kwargs[module] = passed_class_obj[module] + init_kwargs[module] = passed_class_obj.get(module, None) elif len(missing_modules) > 0: - passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) + passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs raise ValueError( f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." ) @@ -541,6 +679,51 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = pipeline_class(**init_kwargs) return model + @staticmethod + def _get_signature_keys(obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - set(["self"]) + return expected_modules, optional_parameters + + @property + def components(self) -> Dict[str, Any]: + r""" + + The `self.components` property can be useful to run different pipelines with the same weights and + configurations to not have to re-allocate memory. + + Examples: + + ```py + >>> from diffusers import ( + ... StableDiffusionPipeline, + ... StableDiffusionImg2ImgPipeline, + ... StableDiffusionInpaintPipeline, + ... ) + + >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components) + >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components) + ``` + + Returns: + A dictionaly containing all the modules needed to initialize the pipeline. + """ + expected_modules, optional_parameters = self._get_signature_keys(self) + components = { + k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters + } + + if set(components.keys()) != expected_modules: + raise ValueError( + f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" + f" {expected_modules} to be defined, but {components} are defined." + ) + + return components + @staticmethod def numpy_to_pil(images): """ @@ -549,7 +732,11 @@ def numpy_to_pil(images): if images.ndim == 3: images = images[None, ...] images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] return pil_images diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md index b5ea112feafc..6ff40d3549b4 100644 --- a/src/diffusers/pipelines/README.md +++ b/src/diffusers/pipelines/README.md @@ -16,7 +16,7 @@ or created independently from each other. To that end, we strive to offer all open-sourced, state-of-the-art diffusion system under a unified API. More specifically, we strive to provide pipelines that -- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LatentDiffusionPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)), +- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)), - 2. have a simple user interface to run the model in inference (see the [Pipelines API](#pipelines-api) section), - 3. are easy to understand with code that is self-explanatory and can be read along-side the official paper (see [Pipelines summary](#pipelines-summary)), - 4. can easily be contributed by the community (see the [Contribution](#contribution) section). @@ -30,19 +30,20 @@ If you are looking for *official* training examples, please have a look at [exam The following table summarizes all officially supported pipelines, their corresponding paper, and if available a colab notebook to directly try them out. -| Pipeline | Paper | Tasks | Colab -|---|---|:---:|:---:| -| [ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | *Unconditional Image Generation* | -| [ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | *Unconditional Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) -| [latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| *Text-to-Image Generation* | -| [latent_diffusion_uncond](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Unconditional Image Generation* | -| [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* | -| [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | -| [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | -| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) -| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) -| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) -| [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* | +| Pipeline | Source | Tasks | Colab +|-------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|:---:|:---:| +| [dance diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/Harmonai-org/sample-generator) | *Unconditional Audio Generation* | +| [ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | *Unconditional Image Generation* | +| [ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | *Unconditional Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Text-to-Image Generation* | +| [latent_diffusion_uncond](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Unconditional Image Generation* | +| [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* | +| [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | +| [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | +| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) +| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) +| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) +| [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* | **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. However, most of them can be adapted to use different scheduler components or even different model components. Some pipeline examples are shown in the [Examples](#examples) below. @@ -55,8 +56,8 @@ Diffusion models often consist of multiple independently-trained models or other Each model has been trained independently on a different task and the scheduler can easily be swapped out and replaced with a different one. During inference, we however want to be able to easily load all components and use them in inference - even if one component, *e.g.* CLIP's text encoder, originates from a different library, such as [Transformers](https://github.com/huggingface/transformers). To that end, all pipelines provide the following functionality: -- [`from_pretrained` method](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L139) that accepts a Hugging Face Hub repository id, *e.g.* [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) or a path to a local directory, *e.g.* -"./stable-diffusion". To correctly retrieve which models and components should be loaded, one has to provide a `model_index.json` file, *e.g.* [CompVis/stable-diffusion-v1-4/model_index.json](https://huggingface.co/CompVis/stable-diffusion-v1-4/blob/main/model_index.json), which defines all components that should be +- [`from_pretrained` method](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L139) that accepts a Hugging Face Hub repository id, *e.g.* [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) or a path to a local directory, *e.g.* +"./stable-diffusion". To correctly retrieve which models and components should be loaded, one has to provide a `model_index.json` file, *e.g.* [runwayml/stable-diffusion-v1-5/model_index.json](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json), which defines all components that should be loaded into the pipelines. More specifically, for each model/component one needs to define the format `: ["", ""]`. `` is the attribute name given to the loaded instance of `` which can be found in the library or pipeline folder called `""`. - [`save_pretrained`](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L90) that accepts a local path, *e.g.* `./stable-diffusion` under which all models/components of the pipeline will be saved. For each component/model a folder is created inside the local path that is named after the given attribute name, *e.g.* `./stable_diffusion/unet`. In addition, a `model_index.json` file is created at the root of the local path, *e.g.* `./stable_diffusion/model_index.json` so that the complete pipeline can again be instantiated @@ -88,7 +89,7 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing # make sure you're logged in with `huggingface-cli login` from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler -pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") +pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" @@ -111,7 +112,7 @@ from diffusers import StableDiffusionImg2ImgPipeline # load the pipeline device = "cuda" pipe = StableDiffusionImg2ImgPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, ).to(device) @@ -141,10 +142,10 @@ You can generate your own latents to reproduce results, or tweak your prompt on The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by providing a mask and text prompt. ```python -from io import BytesIO - -import requests import PIL +import requests +import torch +from io import BytesIO from diffusers import StableDiffusionInpaintPipeline @@ -158,17 +159,15 @@ mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data init_image = download_image(img_url).resize((512, 512)) mask_image = download_image(mask_url).resize((512, 512)) -device = "cuda" pipe = StableDiffusionInpaintPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", - revision="fp16", + "runwayml/stable-diffusion-inpainting", + revision="fp16", torch_dtype=torch.float16, -).to(device) - -prompt = "a cat sitting on a bench" -images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images +) +pipe = pipe.to("cuda") -images[0].save("cat_on_bench.png") +prompt = "Face of a yellow cat, high resolution, sitting on a park bench" +image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] ``` You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 1c31595fb0cf..c5aba302042b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -2,25 +2,47 @@ if is_torch_available(): + from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline from .ddpm import DDPMPipeline + from .latent_diffusion import LDMSuperResolutionPipeline from .latent_diffusion_uncond import LDMPipeline from .pndm import PNDMPipeline + from .repaint import RePaintPipeline from .score_sde_ve import ScoreSdeVePipeline from .stochastic_karras_ve import KarrasVePipeline else: from ..utils.dummy_pt_objects import * # noqa F403 if is_torch_available() and is_transformers_available(): + from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline from .latent_diffusion import LDMTextToImagePipeline from .stable_diffusion import ( + CycleDiffusionPipeline, + StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, + StableDiffusionInpaintPipelineLegacy, StableDiffusionPipeline, + StableDiffusionUpscalePipeline, ) + from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .versatile_diffusion import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) + from .vq_diffusion import VQDiffusionPipeline if is_transformers_available() and is_onnx_available(): - from .stable_diffusion import StableDiffusionOnnxPipeline + from .stable_diffusion import ( + OnnxStableDiffusionImg2ImgPipeline, + OnnxStableDiffusionInpaintPipeline, + OnnxStableDiffusionInpaintPipelineLegacy, + OnnxStableDiffusionPipeline, + StableDiffusionOnnxPipeline, + ) if is_transformers_available() and is_flax_available(): from .stable_diffusion import FlaxStableDiffusionPipeline diff --git a/src/diffusers/pipelines/alt_diffusion/__init__.py b/src/diffusers/pipelines/alt_diffusion/__init__.py new file mode 100644 index 000000000000..09d0d9b7852c --- /dev/null +++ b/src/diffusers/pipelines/alt_diffusion/__init__.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np + +import PIL +from PIL import Image + +from ...utils import BaseOutput, is_torch_available, is_transformers_available + + +@dataclass +# Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt +class AltDiffusionPipelineOutput(BaseOutput): + """ + Output class for Alt Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +if is_transformers_available() and is_torch_available(): + from .modeling_roberta_series import RobertaSeriesModelWithTransformation + from .pipeline_alt_diffusion import AltDiffusionPipeline + from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline diff --git a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py new file mode 100644 index 000000000000..2e92314162d3 --- /dev/null +++ b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py @@ -0,0 +1,110 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn + +from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel +from transformers.utils import ModelOutput + + +@dataclass +class TransformationModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + projection_state: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__( + self, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + project_dim=512, + pooler_fn="cls", + learn_encoder=False, + use_attention_mask=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + self.use_attention_mask = use_attention_mask + + +class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + base_model_prefix = "roberta" + config_class = RobertaSeriesConfig + + def __init__(self, config): + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size, config.project_dim) + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ): + r""" """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + projection_state = self.transformation(outputs.last_hidden_state) + + return TransformationModelOutput( + projection_state=projection_state, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py new file mode 100644 index 000000000000..3bbc3b3fd7ff --- /dev/null +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -0,0 +1,578 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import torch + +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import deprecate, logging +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker +class AltDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Alt Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`RobertaSeriesModelWithTransformation`]): + Frozen text-encoder. Alt Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`XLMRobertaTokenizer`): + Tokenizer of class + [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: RobertaSeriesModelWithTransformation, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py new file mode 100644 index 000000000000..23b4b42b5899 --- /dev/null +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -0,0 +1,611 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker +class AltDiffusionImg2ImgPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image to image generation using Alt Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`RobertaSeriesModelWithTransformation`]): + Frozen text-encoder. Alt Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`XLMRobertaTokenizer`): + Tokenizer of class + [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: RobertaSeriesModelWithTransformation, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, strength, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps + + def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + init_image = init_image.to(device=device, dtype=dtype) + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many init images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Preprocess image + if isinstance(init_image, PIL.Image.Image): + init_image = preprocess(init_image) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + + # 11. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/dance_diffusion/__init__.py b/src/diffusers/pipelines/dance_diffusion/__init__.py new file mode 100644 index 000000000000..2ad34fc52aaa --- /dev/null +++ b/src/diffusers/pipelines/dance_diffusion/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline_dance_diffusion import DanceDiffusionPipeline diff --git a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py new file mode 100644 index 000000000000..48d16889a030 --- /dev/null +++ b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py @@ -0,0 +1,119 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional, Tuple, Union + +import torch + +from ...pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class DanceDiffusionPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`IPNDMScheduler`]. + """ + + def __init__(self, unet, scheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 100, + generator: Optional[torch.Generator] = None, + audio_length_in_s: Optional[float] = None, + return_dict: bool = True, + ) -> Union[AudioPipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of audio samples to generate. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at + the expense of slower inference. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`): + The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.* + `sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.AudioPipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + + if audio_length_in_s is None: + audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate + + sample_size = audio_length_in_s * self.unet.sample_rate + + down_scale_factor = 2 ** len(self.unet.up_blocks) + if sample_size < 3 * down_scale_factor: + raise ValueError( + f"{audio_length_in_s} is too small. Make sure it's bigger or equal to" + f" {3 * down_scale_factor / self.unet.sample_rate}." + ) + + original_sample_size = int(sample_size) + if sample_size % down_scale_factor != 0: + sample_size = ((audio_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor + logger.info( + f"{audio_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled" + f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising" + " process." + ) + sample_size = int(sample_size) + + dtype = next(iter(self.unet.parameters())).dtype + audio = torch.randn( + (batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device, dtype=dtype + ) + + # set step values + self.scheduler.set_timesteps(num_inference_steps, device=audio.device) + self.scheduler.timesteps = self.scheduler.timesteps.to(dtype) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(audio, t).sample + + # 2. compute previous image: x_t -> t_t-1 + audio = self.scheduler.step(model_output, t, audio).prev_sample + + audio = audio.clamp(-1, 1).float().cpu().numpy() + + audio = audio[:, :, :original_sample_size] + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 74607fe87a3d..b9e590dea646 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -10,15 +10,14 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and - # limitations under the License. - from typing import Optional, Tuple, Union import torch from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...utils import deprecate class DDIMPipeline(DiffusionPipeline): @@ -44,6 +43,7 @@ def __call__( generator: Optional[torch.Generator] = None, eta: float = 0.0, num_inference_steps: int = 50, + use_clipped_model_output: Optional[bool] = None, output_type: Optional[str] = "pil", return_dict: bool = True, **kwargs, @@ -60,6 +60,9 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + use_clipped_model_output (`bool`, *optional*, defaults to `None`): + if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed + downstream to the scheduler. So use `None` for schedulers which don't support this argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -72,12 +75,31 @@ def __call__( generated images. """ + if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": + message = ( + f"The `generator` device is `{generator.device}` and does not match the pipeline " + f"device `{self.device}`, so the `generator` will be ignored. " + f'Please use `generator=torch.Generator(device="{self.device}")` instead.' + ) + deprecate( + "generator.device == 'cpu'", + "0.11.0", + message, + ) + generator = None + # Sample gaussian noise to begin loop - image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), - generator=generator, - ) - image = image.to(self.device) + if isinstance(self.unet.sample_size, int): + image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) + else: + image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size) + + if self.device.type == "mps": + # randn does not work reproducibly on mps + image = torch.randn(image_shape, generator=generator) + image = image.to(self.device) + else: + image = torch.randn(image_shape, generator=generator, device=self.device) # set step values self.scheduler.set_timesteps(num_inference_steps) @@ -89,7 +111,9 @@ def __call__( # 2. predict previous mean of image x_t-1 and add variance depending on eta # eta corresponds to η in paper and should be between [0, 1] # do x_t -> x_t-1 - image = self.scheduler.step(model_output, t, image, eta).prev_sample + image = self.scheduler.step( + model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator + ).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index aae29737aae3..31791caf9ebe 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -10,7 +10,6 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and - # limitations under the License. @@ -18,7 +17,9 @@ import torch +from ...configuration_utils import FrozenDict from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...utils import deprecate class DDPMPipeline(DiffusionPipeline): @@ -42,6 +43,7 @@ def __call__( self, batch_size: int = 1, generator: Optional[torch.Generator] = None, + num_inference_steps: int = 1000, output_type: Optional[str] = "pil", return_dict: bool = True, **kwargs, @@ -53,6 +55,9 @@ def __call__( generator (`torch.Generator`, *optional*): A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 1000): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -64,22 +69,51 @@ def __call__( `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " DDPMScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + + if predict_epsilon is not None: + new_config = dict(self.scheduler.config) + new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample" + self.scheduler._internal_dict = FrozenDict(new_config) + + if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": + message = ( + f"The `generator` device is `{generator.device}` and does not match the pipeline " + f"device `{self.device}`, so the `generator` will be ignored. " + f'Please use `torch.Generator(device="{self.device}")` instead.' + ) + deprecate( + "generator.device == 'cpu'", + "0.11.0", + message, + ) + generator = None # Sample gaussian noise to begin loop - image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), - generator=generator, - ) - image = image.to(self.device) + if isinstance(self.unet.sample_size, int): + image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) + else: + image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size) + + if self.device.type == "mps": + # randn does not work reproducibly on mps + image = torch.randn(image_shape, generator=generator) + image = image.to(self.device) + else: + image = torch.randn(image_shape, generator=generator, device=self.device) # set step values - self.scheduler.set_timesteps(1000) + self.scheduler.set_timesteps(num_inference_steps) for t in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output model_output = self.unet(image, t).sample - # 2. compute previous image: x_t -> t_t-1 + # 2. compute previous image: x_t -> x_t-1 image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample image = (image / 2 + 0.5).clamp(0, 1) diff --git a/src/diffusers/pipelines/latent_diffusion/__init__.py b/src/diffusers/pipelines/latent_diffusion/__init__.py index c481b38cf5e0..5544527ff587 100644 --- a/src/diffusers/pipelines/latent_diffusion/__init__.py +++ b/src/diffusers/pipelines/latent_diffusion/__init__.py @@ -1,5 +1,6 @@ # flake8: noqa from ...utils import is_transformers_available +from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline if is_transformers_available(): diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 62a5785beb2f..0e903cb836a3 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -1,3 +1,17 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect from typing import List, Optional, Tuple, Union @@ -32,7 +46,7 @@ class LDMTextToImagePipeline(DiffusionPipeline): [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. """ @@ -46,13 +60,14 @@ def __init__( ): super().__init__() self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], - height: Optional[int] = 256, - width: Optional[int] = 256, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 1.0, eta: Optional[float] = 0.0, @@ -65,9 +80,9 @@ def __call__( Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 256): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 256): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -92,6 +107,9 @@ def __call__( `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor if isinstance(prompt, str): batch_size = 1 @@ -662,6 +680,8 @@ def custom_forward(*inputs): class LDMBertModel(LDMBertPreTrainedModel): + _no_split_modules = [] + def __init__(self, config: LDMBertConfig): super().__init__(config) self.model = LDMBertEncoder(config) diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py new file mode 100644 index 000000000000..b296a4953f97 --- /dev/null +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py @@ -0,0 +1,170 @@ +import inspect +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint + +import PIL + +from ...models import UNet2DModel, VQModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import PIL_INTERPOLATION + + +def preprocess(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +class LDMSuperResolutionPipeline(DiffusionPipeline): + r""" + A pipeline for image super-resolution using Latent + + This class inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) VAE Model to encode and decode images to and from latent representations. + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], + [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vqvae: VQModel, + unet: UNet2DModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ): + super().__init__() + self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + init_image: Union[torch.Tensor, PIL.Image.Image], + batch_size: Optional[int] = 1, + num_inference_steps: Optional[int] = 100, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + r""" + Args: + init_image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + batch_size (`int`, *optional*, defaults to 1): + Number of images to generate. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + + if isinstance(init_image, PIL.Image.Image): + batch_size = 1 + elif isinstance(init_image, torch.Tensor): + batch_size = init_image.shape[0] + else: + raise ValueError( + f"`init_image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(init_image)}" + ) + + if isinstance(init_image, PIL.Image.Image): + init_image = preprocess(init_image) + + height, width = init_image.shape[-2:] + + # in_channels should be 6: 3 for latents, 3 for low resolution image + latents_shape = (batch_size, self.unet.in_channels // 2, height, width) + latents_dtype = next(self.unet.parameters()).dtype + + if self.device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype) + latents = latents.to(self.device) + else: + latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + + init_image = init_image.to(device=self.device, dtype=latents_dtype) + + # set timesteps and move to the correct device + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps_tensor = self.scheduler.timesteps + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature. + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(timesteps_tensor): + # concat latents and low resolution image in the channel dimension. + latents_input = torch.cat([latents, init_image], dim=1) + latents_input = self.scheduler.scale_model_input(latents_input, t) + # predict the noise residual + noise_pred = self.unet(latents_input, t).sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + + # decode the image latents with the VQVAE + image = self.vqvae.decode(latents).sample + image = torch.clamp(image, -1.0, 1.0) + image = image / 2 + 0.5 + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index ef82abb7e6cb..5345c4e5625e 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -1,3 +1,17 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect from typing import Optional, Tuple, Union @@ -18,7 +32,7 @@ class LDMPipeline(DiffusionPipeline): Vector-quantized (VQ) Model to encode and decode images to and from latent representations. unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): - [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latens. + [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents. """ def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler): @@ -64,6 +78,9 @@ def __call__( ) latents = latents.to(self.device) + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + self.scheduler.set_timesteps(num_inference_steps) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -74,8 +91,9 @@ def __call__( extra_kwargs["eta"] = eta for t in self.progress_bar(self.scheduler.timesteps): + latent_model_input = self.scheduler.scale_model_input(latents, t) # predict the noise residual - noise_prediction = self.unet(latents, t).sample + noise_prediction = self.unet(latent_model_input, t).sample # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index f360da09ac94..ef7062dea19c 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -10,7 +10,6 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and - # limitations under the License. diff --git a/src/diffusers/pipelines/repaint/__init__.py b/src/diffusers/pipelines/repaint/__init__.py new file mode 100644 index 000000000000..16bc86d1cedf --- /dev/null +++ b/src/diffusers/pipelines/repaint/__init__.py @@ -0,0 +1 @@ +from .pipeline_repaint import RePaintPipeline diff --git a/src/diffusers/pipelines/repaint/pipeline_repaint.py b/src/diffusers/pipelines/repaint/pipeline_repaint.py new file mode 100644 index 000000000000..7af88f627559 --- /dev/null +++ b/src/diffusers/pipelines/repaint/pipeline_repaint.py @@ -0,0 +1,140 @@ +# Copyright 2022 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +import PIL +from tqdm.auto import tqdm + +from ...models import UNet2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import RePaintScheduler + + +def _preprocess_image(image: PIL.Image.Image): + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + return image + + +def _preprocess_mask(mask: PIL.Image.Image): + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + return mask + + +class RePaintPipeline(DiffusionPipeline): + unet: UNet2DModel + scheduler: RePaintScheduler + + def __init__(self, unet, scheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + original_image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + num_inference_steps: int = 250, + eta: float = 0.0, + jump_length: int = 10, + jump_n_sample: int = 10, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + original_image (`torch.FloatTensor` or `PIL.Image.Image`): + The original image to inpaint on. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + The mask_image where 0.0 values define which part of the original image to inpaint (change). + num_inference_steps (`int`, *optional*, defaults to 1000): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + eta (`float`): + The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 - 0.0 is DDIM + and 1.0 is DDPM scheduler respectively. + jump_length (`int`, *optional*, defaults to 10): + The number of steps taken forward in time before going backward in time for a single jump ("j" in + RePaint paper). Take a look at Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf. + jump_n_sample (`int`, *optional*, defaults to 10): + The number of times we will make forward time jump for a given chosen time sample. Take a look at + Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + + if not isinstance(original_image, torch.FloatTensor): + original_image = _preprocess_image(original_image) + original_image = original_image.to(self.device) + if not isinstance(mask_image, torch.FloatTensor): + mask_image = _preprocess_mask(mask_image) + mask_image = mask_image.to(self.device) + + # sample gaussian noise to begin the loop + image = torch.randn( + original_image.shape, + generator=generator, + device=self.device, + ) + image = image.to(self.device) + + # set step values + self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device) + self.scheduler.eta = eta + + t_last = self.scheduler.timesteps[0] + 1 + for i, t in enumerate(tqdm(self.scheduler.timesteps)): + if t < t_last: + # predict the noise residual + model_output = self.unet(image, t).sample + # compute previous image: x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample + + else: + # compute the reverse: x_t-1 -> x_t + image = self.scheduler.undo_step(image, t_last, generator) + t_last = t + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 7f63820eec28..7eb6a5d3cbd4 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -1,4 +1,17 @@ -#!/usr/bin/env python3 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Optional, Tuple, Union import torch diff --git a/src/diffusers/pipelines/stable_diffusion/README.md b/src/diffusers/pipelines/stable_diffusion/README.md index 47c38acbdb35..bc30be4a7b9d 100644 --- a/src/diffusers/pipelines/stable_diffusion/README.md +++ b/src/diffusers/pipelines/stable_diffusion/README.md @@ -13,7 +13,7 @@ The summary of the model is the following: - Stable Diffusion has the same architecture as [Latent Diffusion](https://arxiv.org/abs/2112.10752) but uses a frozen CLIP Text Encoder instead of training the text encoder jointly with the diffusion model. - An in-detail explanation of the Stable Diffusion model can be found under [Stable Diffusion with 🧨 Diffusers](https://huggingface.co/blog/stable_diffusion). - If you don't want to rely on the Hugging Face Hub and having to pass a authentication token, you can -download the weights with `git lfs install; git clone https://huggingface.co/CompVis/stable-diffusion-v1-4` and instead pass the local path to the cloned folder to `from_pretrained` as shown below. +download the weights with `git lfs install; git clone https://huggingface.co/runwayml/stable-diffusion-v1-5` and instead pass the local path to the cloned folder to `from_pretrained` as shown below. - Stable Diffusion can work with a variety of different samplers as is shown below. ## Available Pipelines: @@ -33,14 +33,14 @@ If you want to download the model weights using a single Python line, you need t ```python from diffusers import DiffusionPipeline -pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") +pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") ``` -This however can make it difficult to build applications on top of `diffusers` as you will always have to pass the token around. A potential way to solve this issue is by downloading the weights to a local path `"./stable-diffusion-v1-4"`: +This however can make it difficult to build applications on top of `diffusers` as you will always have to pass the token around. A potential way to solve this issue is by downloading the weights to a local path `"./stable-diffusion-v1-5"`: ``` git lfs install -git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 +git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 ``` and simply passing the local path to `from_pretrained`: @@ -48,7 +48,7 @@ and simply passing the local path to `from_pretrained`: ```python from diffusers import StableDiffusionPipeline -pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4") +pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-5") ``` ### Text-to-Image with default PLMS scheduler @@ -57,7 +57,7 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4") # make sure you're logged in with `huggingface-cli login` from diffusers import StableDiffusionPipeline -pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") +pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" @@ -72,10 +72,10 @@ image.save("astronaut_rides_horse.png") # make sure you're logged in with `huggingface-cli login` from diffusers import StableDiffusionPipeline, DDIMScheduler -scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) +scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") pipe = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", scheduler=scheduler, ).to("cuda") @@ -91,14 +91,10 @@ image.save("astronaut_rides_horse.png") # make sure you're logged in with `huggingface-cli login` from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler -lms = LMSDiscreteScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear" -) +lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") pipe = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", scheduler=lms, ).to("cuda") @@ -107,3 +103,74 @@ image = pipe(prompt).sample[0] image.save("astronaut_rides_horse.png") ``` + +### CycleDiffusion using Stable Diffusion and DDIM scheduler + +```python +import requests +import torch +from PIL import Image +from io import BytesIO + +from diffusers import CycleDiffusionPipeline, DDIMScheduler + + +# load the scheduler. CycleDiffusion only supports stochastic schedulers. + +# load the pipeline +# make sure you're logged in with `huggingface-cli login` +model_id_or_path = "CompVis/stable-diffusion-v1-4" +scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler") +pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda") + +# let's download an initial image +url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png" +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((512, 512)) +init_image.save("horse.png") + +# let's specify a prompt +source_prompt = "An astronaut riding a horse" +prompt = "An astronaut riding an elephant" + +# call the pipeline +image = pipe( + prompt=prompt, + source_prompt=source_prompt, + init_image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.8, + guidance_scale=2, + source_guidance_scale=1, +).images[0] + +image.save("horse_to_elephant.png") + +# let's try another example +# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion +url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png" +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((512, 512)) +init_image.save("black.png") + +source_prompt = "A black colored car" +prompt = "A blue colored car" + +# call the pipeline +torch.manual_seed(0) +image = pipe( + prompt=prompt, + source_prompt=source_prompt, + init_image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.85, + guidance_scale=3, + source_guidance_scale=1, +).images[0] + +image.save("black_to_blue.png") +``` diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 8c07afe58fc2..0136ab565bcb 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -6,7 +6,14 @@ import PIL from PIL import Image -from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available +from ...utils import ( + BaseOutput, + is_flax_available, + is_onnx_available, + is_torch_available, + is_transformers_available, + is_transformers_version, +) @dataclass @@ -28,13 +35,24 @@ class StableDiffusionPipelineOutput(BaseOutput): if is_transformers_available() and is_torch_available(): + from .pipeline_cycle_diffusion import CycleDiffusionPipeline from .pipeline_stable_diffusion import StableDiffusionPipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline + from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy + from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline from .safety_checker import StableDiffusionSafetyChecker +if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"): + from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline +else: + from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline + if is_transformers_available() and is_onnx_available(): - from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline + from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline + from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline + from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline + from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy if is_transformers_available() and is_flax_available(): import flax diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py new file mode 100644 index 000000000000..83848905fd4c --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -0,0 +1,699 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler +from ...utils import PIL_INTERPOLATION, deprecate, logging +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta): + # 1. get previous step value (=t-1) + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + + if prev_timestep <= 0: + return clean_latents + + # 2. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + ) + + variance = scheduler._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # direction pointing to x_t + e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5) + dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t + noise = std_dev_t * torch.randn( + clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator + ) + prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise + + return prev_latents + + +def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta): + # 1. get previous step value (=t-1) + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + if scheduler.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = scheduler._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred + + noise = (prev_latents - (alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction)) / ( + variance ** (0.5) * eta + ) + return noise + + +class CycleDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + + self.unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs + def check_inputs(self, prompt, strength, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps + + def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + init_image = init_image.to(device=device, dtype=dtype) + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many init images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + + # add noise to latents using the timestep + noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) + + # get latents + clean_latents = init_latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents, clean_latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + source_prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + source_guidance_scale: Optional[float] = 1, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.1, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + source_guidance_scale (`float`, *optional*, defaults to 1): + Guidance scale for the source prompt. This is useful to control the amount of influence the source + prompt for encoding. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.1): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, None) + source_text_embeddings = self._encode_prompt( + source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None + ) + + # 4. Preprocess image + if isinstance(init_image, PIL.Image.Image): + init_image = preprocess(init_image) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents, clean_latents = self.prepare_latents( + init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator + ) + source_latents = latents + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + generator = extra_step_kwargs.pop("generator", None) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + source_latent_model_input = torch.cat([source_latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t) + + # predict the noise residual + concat_latent_model_input = torch.stack( + [ + source_latent_model_input[0], + latent_model_input[0], + source_latent_model_input[1], + latent_model_input[1], + ], + dim=0, + ) + concat_text_embeddings = torch.stack( + [ + source_text_embeddings[0], + text_embeddings[0], + source_text_embeddings[1], + text_embeddings[1], + ], + dim=0, + ) + concat_noise_pred = self.unet( + concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings + ).sample + + # perform guidance + ( + source_noise_pred_uncond, + noise_pred_uncond, + source_noise_pred_text, + noise_pred_text, + ) = concat_noise_pred.chunk(4, dim=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + source_noise_pred = source_noise_pred_uncond + source_guidance_scale * ( + source_noise_pred_text - source_noise_pred_uncond + ) + + # Sample source_latents from the posterior distribution. + prev_source_latents = posterior_sample( + self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs + ) + # Compute noise. + noise = compute_noise( + self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs + ) + source_latents = prev_source_latents + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + + # 11. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 18c008f8806b..dbe3b7db9d38 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -1,3 +1,18 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings from functools import partial from typing import Dict, List, Optional, Union @@ -8,13 +23,19 @@ from flax.core.frozen_dict import FrozenDict from flax.jax_utils import unreplicate from flax.training.common_utils import shard +from packaging import version from PIL import Image from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...pipeline_flax_utils import FlaxDiffusionPipeline -from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler -from ...utils import logging +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import deprecate, logging from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker @@ -41,11 +62,12 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of - [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`]. + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. safety_checker ([`FlaxStableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -56,7 +78,9 @@ def __init__( text_encoder: FlaxCLIPTextModel, tokenizer: CLIPTokenizer, unet: FlaxUNet2DConditionModel, - scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler], + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], safety_checker: FlaxStableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, dtype: jnp.dtype = jnp.float32, @@ -65,7 +89,7 @@ def __init__( self.dtype = dtype if safety_checker is None: - logger.warn( + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" @@ -74,6 +98,27 @@ def __init__( " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -83,6 +128,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) def prepare_inputs(self, prompt: Union[str, List[str]]): if not isinstance(prompt, (str, list)): @@ -97,9 +143,9 @@ def prepare_inputs(self, prompt: Union[str, List[str]]): ) return text_input.input_ids - def _get_safety_scores(self, features, params): - special_cos_dist, cos_dist = self.safety_checker(features, params) - return (special_cos_dist, cos_dist) + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts def _run_safety_checker(self, images, safety_model_params, jit=False): # safety_model_params should already be replicated when jit is True @@ -108,20 +154,28 @@ def _run_safety_checker(self, images, safety_model_params, jit=False): if jit: features = shard(features) - special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params) - special_cos_dist = unshard(special_cos_dist) - cos_dist = unshard(cos_dist) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) safety_model_params = unreplicate(safety_model_params) else: - special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params) + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) - images, has_nsfw = self.safety_checker.filtered_with_scores( - special_cos_dist, - cos_dist, - images, - safety_model_params, - ) - return images, has_nsfw + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts def _generate( self, @@ -129,12 +183,17 @@ def _generate( params: Union[Dict, FrozenDict], prng_seed: jax.random.PRNGKey, num_inference_steps: int = 50, - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, guidance_scale: float = 7.5, latents: Optional[jnp.array] = None, debug: bool = False, + neg_prompt_ids: jnp.array = None, ): + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -146,13 +205,22 @@ def _generate( batch_size = prompt_ids.shape[0] max_length = prompt_ids.shape[-1] - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" - ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0] context = jnp.concatenate([uncond_embeddings, text_embeddings]) - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + latents_shape = ( + batch_size, + self.unet.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) if latents is None: latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) else: @@ -213,13 +281,14 @@ def __call__( params: Union[Dict, FrozenDict], prng_seed: jax.random.PRNGKey, num_inference_steps: int = 50, - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, guidance_scale: float = 7.5, latents: jnp.array = None, return_dict: bool = True, jit: bool = False, debug: bool = False, + neg_prompt_ids: jnp.array = None, **kwargs, ): r""" @@ -228,9 +297,9 @@ def __call__( Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -265,13 +334,36 @@ def __call__( element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + if jit: images = _p_generate( - self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + self, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + debug, + neg_prompt_ids, ) else: images = self._generate( - prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + debug, + neg_prompt_ids, ) if self.safety_checker is not None: @@ -302,16 +394,35 @@ def __call__( # TODO: maybe use a config dict instead of so many static argnums @partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9)) def _p_generate( - pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + pipe, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + debug, + neg_prompt_ids, ): return pipe._generate( - prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + debug, + neg_prompt_ids, ) @partial(jax.pmap, static_broadcasted_argnums=(0,)) -def _p_get_safety_scores(pipe, features, params): - return pipe._get_safety_scores(features, params) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) def unshard(x: jnp.ndarray): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py new file mode 100644 index 000000000000..6cb2c8ba87c7 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -0,0 +1,354 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +from transformers import CLIPFeatureExtractor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import deprecate, logging +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) + + +class OnnxStableDiffusionPipeline(DiffusionPipeline): + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPFeatureExtractor + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + text_embeddings = self._encode_prompt( + prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # get the initial random noise unless the user supplied it + latents_dtype = text_embeddings.dtype + latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + elif latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + latents = latents * np.float(self.scheduler.init_noise_sigma) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + + image, has_nsfw_concepts = self.safety_checker(clip_input=safety_checker_input, images=image) + + # There will throw an error if use safety_checker batchsize>1 + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline): + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPFeatureExtractor, + ): + deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`." + deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message) + super().__init__( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py new file mode 100644 index 000000000000..949ef94b3a5a --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -0,0 +1,455 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from transformers import CLIPFeatureExtractor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import PIL_INTERPOLATION, deprecate, logging +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 + + +class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPFeatureExtractor + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt + def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def __call__( + self, + prompt: Union[str, List[str]], + init_image: Union[np.ndarray, PIL.Image.Image], + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[np.random.RandomState] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + init_image (`np.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + if isinstance(init_image, PIL.Image.Image): + init_image = preprocess(init_image) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + text_embeddings = self._encode_prompt( + prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + latents_dtype = text_embeddings.dtype + init_image = init_image.astype(latents_dtype) + # encode the init image into latents and scale the latents + init_latents = self.vae_encoder(sample=init_image)[0] + init_latents = 0.18215 * init_latents + + if isinstance(prompt, str): + prompt = [prompt] + if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many init images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = len(prompt) // init_latents.shape[0] + init_latents = np.concatenate([init_latents] * additional_image_per_prompt * num_images_per_prompt, axis=0) + elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts." + ) + else: + init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps.numpy()[-init_timestep] + timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(latents_dtype) + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) + ) + init_latents = init_latents.numpy() + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].numpy() + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet( + sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # safety_checker does not support batched inputs yet + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py new file mode 100644 index 000000000000..0a8f7a5fc580 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -0,0 +1,479 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from transformers import CLIPFeatureExtractor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import PIL_INTERPOLATION, deprecate, logging +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +NUM_UNET_INPUT_CHANNELS = 9 +NUM_LATENT_CHANNELS = 4 + + +def prepare_mask_and_masked_image(image, mask, latents_shape): + image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8))) + image = image[None].transpose(0, 3, 1, 2) + image = image.astype(np.float32) / 127.5 - 1.0 + + image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8))) + masked_image = image * (image_mask < 127.5) + + mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"]) + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + return mask, masked_image + + +class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPFeatureExtractor + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt + def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + + return text_embeddings + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: PIL.Image.Image, + mask_image: PIL.Image.Image, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + latents (`np.ndarray`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + text_embeddings = self._encode_prompt( + prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + num_channels_latents = NUM_LATENT_CHANNELS + latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) + latents_dtype = text_embeddings.dtype + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # prepare mask and masked_image + mask, masked_image = prepare_mask_and_masked_image(image, mask_image, latents_shape[-2:]) + mask = mask.astype(latents.dtype) + masked_image = masked_image.astype(latents.dtype) + + masked_image_latents = self.vae_encoder(sample=masked_image)[0] + masked_image_latents = 0.18215 * masked_image_latents + + # duplicate mask and masked_image_latents for each generation per prompt + mask = mask.repeat(batch_size * num_images_per_prompt, 0) + masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0) + + mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + + unet_input_channels = NUM_UNET_INPUT_CHANNELS + if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels: + raise ValueError( + "Incorrect configuration settings! The config of `pipeline.unet` expects" + f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * np.float(self.scheduler.init_noise_sigma) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + # concat latents, mask, masked_image_latnets in the channel dimension + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1) + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet( + sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # safety_checker does not support batched inputs yet + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py new file mode 100644 index 000000000000..631e7129e950 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,478 @@ +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...onnx_utils import OnnxRuntimeModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import deprecate, logging +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, scale_factor=8): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + return mask + + +class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. This is a *legacy feature* for Onnx pipelines to + provide compatibility with StableDiffusionInpaintPipelineLegacy and may be removed in the future. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPFeatureExtractor + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt + def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def __call__( + self, + prompt: Union[str, List[str]], + init_image: Union[np.ndarray, PIL.Image.Image], + mask_image: Union[np.ndarray, PIL.Image.Image], + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[np.random.RandomState] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + init_image (`nd.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`nd.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.uu + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (?) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + if isinstance(init_image, PIL.Image.Image): + init_image = preprocess(init_image) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + text_embeddings = self._encode_prompt( + prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + latents_dtype = text_embeddings.dtype + init_image = init_image.astype(latents_dtype) + + # encode the init image into latents and scale the latents + init_latents = self.vae_encoder(sample=init_image)[0] + init_latents = 0.18215 * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) + init_latents_orig = init_latents + + # preprocess mask + if not isinstance(mask_image, np.ndarray): + mask_image = preprocess_mask(mask_image, self.vae_scale_factor) + mask_image = mask_image.astype(latents_dtype) + mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps.numpy()[-init_timestep] + timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(latents_dtype) + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) + ) + init_latents = init_latents.numpy() + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (?) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to ? in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].numpy() + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ).prev_sample + + latents = latents.numpy() + + init_latents_proper = self.scheduler.add_noise( + torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.from_numpy(np.array([t])) + ) + + init_latents_proper = init_latents_proper.numpy() + + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # There will throw an error if use safety_checker batchsize>1 + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 8ae51999a7b3..3739ae7a6da0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -1,14 +1,37 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect from typing import Callable, List, Optional, Union import torch +from diffusers.utils import is_accelerate_available +from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -36,14 +59,15 @@ class StableDiffusionPipeline(DiffusionPipeline): [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, @@ -51,9 +75,17 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() @@ -71,8 +103,21 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if safety_checker is None: - logger.warn( + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" @@ -81,6 +126,33 @@ def __init__( " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -90,6 +162,26 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" @@ -105,9 +197,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto `attention_head_dim` must be a multiple of `slice_size`. """ if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + self.unet.set_attention_slice(slice_size) def disable_attention_slicing(self): @@ -118,12 +215,224 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -143,9 +452,9 @@ def __call__( Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -191,133 +500,48 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - - if text_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] - - # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - # get the initial random noise unless the user supplied it - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. - latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) - latents_dtype = text_embeddings.dtype - if latents is None: - if self.device.type == "mps": - # randn does not exist on mps - latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( - self.device - ) - else: - latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) - # Some schedulers like PNDM have timesteps as arrays - # It's more optimized to move all timesteps to correct device beforehand - timesteps_tensor = self.scheduler.timesteps.to(self.device) + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -337,24 +561,13 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) + # 8. Post-processing + image = self.decode_latents(latents) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( - self.device - ) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) - ) - else: - has_nsfw_concept = None + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + # 10. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py new file mode 100644 index 000000000000..5e6aa9885cf8 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -0,0 +1,478 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import torch + +import PIL +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import deprecate, logging +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionImageVariationPipeline(DiffusionPipeline): + r""" + Pipeline to generate variations from an input image using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + image_encoder: CLIPVisionModelWithProjection, + unet: UNet2DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + + self.unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.image_encoder, self.vae, self.safety_checker]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + uncond_embeddings = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The image or images to guide the image generation. If you provide a tensor, it needs to comply with the + configuration of + [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) + `CLIPFeatureExtractor` + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input image + image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 799fd459bbd9..8fe86992af20 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -1,3 +1,17 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect from typing import Callable, List, Optional, Union @@ -5,13 +19,22 @@ import torch import PIL +from diffusers.utils import is_accelerate_available +from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import deprecate, logging +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -22,7 +45,7 @@ def preprocess(image): w, h = image.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) @@ -48,24 +71,34 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor"] + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() @@ -83,8 +116,21 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if safety_checker is None: - logger.warn( + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" @@ -93,6 +139,33 @@ def __init__( " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -102,7 +175,10 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" Enable sliced attention computation. @@ -117,19 +193,290 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto `attention_head_dim` must be a multiple of `slice_size`. """ if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + self.unet.set_attention_slice(slice_size) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing def disable_attention_slicing(self): r""" Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go back to computing attention in one step. """ - # set slice_size = `None` to disable `set_attention_slice` + # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, strength, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps + + def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + init_image = init_image.to(device=device, dtype=dtype) + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many init images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + @torch.no_grad() def __call__( self, @@ -203,146 +550,40 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) - - if isinstance(init_image, PIL.Image.Image): - init_image = preprocess(init_image) - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - - if text_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] - - # duplicate text embeddings for each generation per prompt - text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps) + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError("The length of `negative_prompt` should be equal to batch_size.") - else: - uncond_tokens = negative_prompt - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] - - # duplicate unconditional embeddings for each generation per prompt - uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - # encode the init image into latents and scale the latents - latents_dtype = text_embeddings.dtype - init_image = init_image.to(device=self.device, dtype=latents_dtype) - init_latent_dist = self.vae.encode(init_image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents - - if isinstance(prompt, str): - prompt = [prompt] - if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0: - # expand init_latents for batch_size - deprecation_message = ( - f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial" - " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many init images as text prompts to suppress this warning." - ) - deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = len(prompt) // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) - elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts." - ) - else: - init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) - - # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype) - init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta + # 4. Preprocess image + if isinstance(init_image, PIL.Image.Image): + init_image = preprocess(init_image) - latents = init_latents + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - t_start = max(num_inference_steps - init_timestep + offset, 0) + # 6. Prepare latent variables + latents = self.prepare_latents( + init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator + ) - # Some schedulers like PNDM have timesteps as arrays - # It's more optimized to move all timesteps to correct device beforehand - timesteps = self.scheduler.timesteps[t_start:].to(self.device) + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 8. Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -363,22 +604,13 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() + # 9. Post-processing + image = self.decode_latents(latents) - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( - self.device - ) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) - ) - else: - has_nsfw_concept = None + # 10. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + # 11. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index cb4d552a6f50..8cefffbb8eb1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -1,3 +1,17 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect from typing import Callable, List, Optional, Union @@ -5,7 +19,8 @@ import torch import PIL -from tqdm.auto import tqdm +from diffusers.utils import is_accelerate_available +from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict @@ -17,30 +32,96 @@ from .safety_checker import StableDiffusionSafetyChecker -logger = logging.get_logger(__name__) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def preprocess_image(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 +def prepare_mask_and_masked_image(image, mask): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. -def preprocess_mask(mask): - mask = mask.convert("L") - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? - mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + if isinstance(image, PIL.Image.Image): + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + if isinstance(mask, PIL.Image.Image): + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + return mask, masked_image class StableDiffusionInpaintPipeline(DiffusionPipeline): @@ -62,14 +143,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, @@ -80,9 +162,9 @@ def __init__( scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() - logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( @@ -98,8 +180,22 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if safety_checker is None: - logger.warn( + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" @@ -108,6 +204,33 @@ def __init__( " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -117,7 +240,10 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" Enable sliced attention computation. @@ -132,148 +258,145 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto `attention_head_dim` must be a multiple of `slice_size`. """ if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + self.unet.set_attention_slice(slice_size) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing def disable_attention_slicing(self): r""" Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go back to computing attention in one step. """ - # set slice_size = `None` to disable `set_attention_slice` + # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - init_image: Union[torch.FloatTensor, PIL.Image.Image], - mask_image: Union[torch.FloatTensor, PIL.Image.Image], - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - **kwargs, - ): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): r""" - Function invoked when calling the pipeline for generation. + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - init_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. This is the image whose masked region will be inpainted. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` - is 1, the denoising process will be run on the masked area for the full number of iterations specified - in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more - noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. - num_inference_steps (`int`, *optional*, defaults to 50): - The reference number of denoising steps. More denoising steps usually lead to a higher quality image at - the expense of slower inference. This parameter will be modulated by `strength`, as explained above. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. + device = torch.device(f"cuda:{gpu_id}") - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + self.unet.set_use_memory_efficient_attention_xformers(True) - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 - # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, + truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids - if text_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] - # duplicate text embeddings for each generation per prompt - text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: - uncond_tokens = [""] + uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" @@ -298,74 +421,292 @@ def __call__( truncation=True, return_tensors="pt", ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] - # duplicate unconditional embeddings for each generation per prompt - uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0) + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - # preprocess image - if not isinstance(init_image, torch.FloatTensor): - init_image = preprocess_image(init_image) - - # encode the init image into latents and scale the latents - latents_dtype = text_embeddings.dtype - init_image = init_image.to(device=self.device, dtype=latents_dtype) - init_latent_dist = self.vae.encode(init_image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents - - # Expand init_latents for batch_size and num_images_per_prompt - init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) - init_latents_orig = init_latents - - # preprocess mask - if not isinstance(mask_image, torch.FloatTensor): - mask_image = preprocess_mask(mask_image) - mask_image = mask_image.to(device=self.device, dtype=latents_dtype) - mask = torch.cat([mask_image] * batch_size * num_images_per_prompt) - - # check sizes - if not mask.shape == init_latents.shape: - raise ValueError("The mask and init_image should be the same size!") - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + return text_embeddings - # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype) - init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta - latents = init_latents + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + # encode the mask image into latents space so we can concatenate it to the latents + masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) + masked_image_latents = 0.18215 * masked_image_latents + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + mask = mask.repeat(batch_size, 1, 1, 1) + masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs(prompt, height, width, callback_steps) - t_start = max(num_inference_steps - init_timestep + offset, 0) + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) - # Some schedulers like PNDM have timesteps as arrays - # It's more optimized to move all timesteps to correct device beforehand - timesteps = self.scheduler.timesteps[t_start:].to(self.device) + # 4. Preprocess mask and image + if isinstance(image, PIL.Image.Image) and isinstance(mask_image, PIL.Image.Image): + mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + text_embeddings.dtype, + device, + generator, + do_classifier_free_guidance, + ) - for i, t in tqdm(enumerate(timesteps)): + # 8. Check that sizes of mask, masked image and latents match + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -377,29 +718,18 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) - - latents = (init_latents_proper * mask) + (latents * (1 - mask)) # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() + # 11. Post-processing + image = self.decode_latents(latents) - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( - self.device - ) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) - else: - has_nsfw_concept = None + # 12. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + # 13. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py new file mode 100644 index 000000000000..1d2c939fef49 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,635 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, deprecate, logging +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, scale_factor=8): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + + self.unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs + def check_inputs(self, prompt, strength, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps + + def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator): + init_image = init_image.to(device=self.device, dtype=dtype) + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) + init_latents_orig = init_latents + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=dtype) + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + return latents, init_latents_orig, noise + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Preprocess image and mask + if not isinstance(init_image, torch.FloatTensor): + init_image = preprocess_image(init_image) + + if not isinstance(mask_image, torch.FloatTensor): + mask_image = preprocess_mask(mask_image, self.vae_scale_factor) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + # encode the init image into latents and scale the latents + latents, init_latents_orig, noise = self.prepare_latents( + init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator + ) + + # 7. Prepare mask latent + mask = mask_image.to(device=self.device, dtype=latents.dtype) + mask = torch.cat([mask] * batch_size * num_images_per_prompt) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 10. Post-processing + image = self.decode_latents(latents) + + # 11. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + + # 12. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py deleted file mode 100644 index acd6446af6fe..000000000000 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ /dev/null @@ -1,200 +0,0 @@ -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np - -from transformers import CLIPFeatureExtractor, CLIPTokenizer - -from ...onnx_utils import OnnxRuntimeModel -from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import logging -from . import StableDiffusionPipelineOutput - - -logger = logging.get_logger(__name__) - - -class StableDiffusionOnnxPipeline(DiffusionPipeline): - vae_decoder: OnnxRuntimeModel - text_encoder: OnnxRuntimeModel - tokenizer: CLIPTokenizer - unet: OnnxRuntimeModel - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] - safety_checker: OnnxRuntimeModel - feature_extractor: CLIPFeatureExtractor - - def __init__( - self, - vae_decoder: OnnxRuntimeModel, - text_encoder: OnnxRuntimeModel, - tokenizer: CLIPTokenizer, - unet: OnnxRuntimeModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: OnnxRuntimeModel, - feature_extractor: CLIPFeatureExtractor, - ): - super().__init__() - self.register_modules( - vae_decoder=vae_decoder, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - - def __call__( - self, - prompt: Union[str, List[str]], - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - eta: Optional[float] = 0.0, - latents: Optional[np.ndarray] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, - callback_steps: Optional[int] = 1, - **kwargs, - ): - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - - if text_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] * batch_size - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="np", - ) - uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) - - # get the initial random noise unless the user supplied it - latents_shape = (batch_size, 4, height // 8, width // 8) - if latents is None: - latents = np.random.randn(*latents_shape).astype(np.float32) - elif latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) - - latents = latents * self.scheduler.init_noise_sigma - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings - ) - noise_pred = noise_pred[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - latents = np.array(latents) - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - latents = 1 / 0.18215 * latents - image = self.vae_decoder(latent_sample=latents)[0] - - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) - - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") - image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image) - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py new file mode 100644 index 000000000000..7ccb43d46c14 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -0,0 +1,551 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from diffusers.utils import is_accelerate_available +from transformers import CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + # resize to multiple of 64 + width, height = image.size + width = width - width % 64 + height = height - height % 64 + image = image.resize((width, height)) + + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + return image + + +class StableDiffusionUpscalePipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image super-resolution using Stable Diffusion 2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + low_res_scheduler ([`SchedulerMixin`]): + A scheduler used to add initial noise to the low res conditioning image. It must be an instance of + [`DDPMScheduler`]. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + low_res_scheduler: DDPMScheduler, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + max_noise_level: int = 350, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + ) + self.register_to_config(max_noise_level=max_noise_level) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + + self.unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333 + def decode_latents(self, latents): + latents = 1 / 0.08333 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs(self, prompt, image, noise_level, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" + ) + + # verify batch size of prompt and image are same if image is a list or tensor + if isinstance(image, list) or isinstance(image, torch.Tensor): + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + if isinstance(image, list): + image_batch_size = len(image) + else: + image_batch_size = image.shape[0] + if batch_size != image_batch_size: + raise ValueError( + f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." + " Please make sure that passed `prompt` matches the batch size of `image`." + ) + + # check noise level + if noise_level > self.config.max_noise_level: + raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height, width) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]], + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + noise_level: int = 20, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): + `Image`, or tensor representing an image batch which will be upscaled. * + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs + self.check_inputs(prompt, image, noise_level, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Preprocess image + image = [image] if isinstance(image, PIL.Image.Image) else image + if isinstance(image, list): + image = [preprocess(img) for img in image] + image = torch.cat(image, dim=0) + image = image.to(dtype=text_embeddings.dtype, device=device) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + # 5. Add noise to image + noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) + if device.type == "mps": + # randn does not work reproducibly on mps + noise = torch.randn(image.shape, generator=generator, device="cpu", dtype=text_embeddings.dtype).to(device) + else: + noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype) + image = self.low_res_scheduler.add_noise(image, noise, noise_level) + image = torch.cat([image] * 2) if do_classifier_free_guidance else image + noise_level = torch.cat([noise_level] * 2) if do_classifier_free_guidance else noise_level + + # 6. Prepare latent variables + height, width = image.shape[2:] + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, image], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 10. Post-processing + # make sure the VAE is in float32 mode, as it overflows in float16 + self.vae.to(dtype=torch.float32) + image = self.decode_latents(latents.float()) + + # 11. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 09d7a3bbf95a..0477c983ead4 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -1,3 +1,17 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np import torch import torch.nn as nn diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index b3cd8eef02fa..e1f669d22b76 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -1,7 +1,18 @@ -import warnings -from typing import Optional, Tuple +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -import numpy as np +from typing import Optional, Tuple import jax import jax.numpy as jnp @@ -39,56 +50,22 @@ def __call__(self, clip_input): special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds) cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds) - return special_cos_dist, cos_dist - - def filtered_with_scores(self, special_cos_dist, cos_dist, images): - batch_size = special_cos_dist.shape[0] - special_cos_dist = np.asarray(special_cos_dist) - cos_dist = np.asarray(cos_dist) - - result = [] - for i in range(batch_size): - result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} - - # increase this value to create a stronger `nfsw` filter - # at the cost of increasing the possibility of filtering benign image inputs - adjustment = 0.0 - - for concept_idx in range(len(special_cos_dist[0])): - concept_cos = special_cos_dist[i][concept_idx] - concept_threshold = self.special_care_embeds_weights[concept_idx].item() - result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["special_scores"][concept_idx] > 0: - result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) - adjustment = 0.01 - - for concept_idx in range(len(cos_dist[0])): - concept_cos = cos_dist[i][concept_idx] - concept_threshold = self.concept_embeds_weights[concept_idx].item() - result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["concept_scores"][concept_idx] > 0: - result_img["bad_concepts"].append(concept_idx) - result.append(result_img) + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign image inputs + adjustment = 0.0 - has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment + special_scores = jnp.round(special_scores, 3) + is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True) + # Use a lower threshold if an image has any special care concept + special_adjustment = is_special_care * 0.01 - images_was_copied = False - for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): - if has_nsfw_concept: - if not images_was_copied: - images_was_copied = True - images = images.copy() + concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment + concept_scores = jnp.round(concept_scores, 3) + has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1) - images[idx] = np.zeros(images[idx].shape) # black image - - if any(has_nsfw_concepts): - warnings.warn( - "Potential NSFW content was detected in one or more images. A black image will be returned" - " instead. Try again with a different prompt and/or seed." - ) - - return images, has_nsfw_concepts + return has_nsfw_concepts class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): @@ -133,15 +110,3 @@ def __call__( jnp.array(clip_input, dtype=jnp.float32), rngs={}, ) - - def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None): - def _filtered_with_scores(module, special_cos_dist, cos_dist, images): - return module.filtered_with_scores(special_cos_dist, cos_dist, images) - - return self.module.apply( - {"params": params or self.params}, - special_cos_dist, - cos_dist, - images, - method=_filtered_with_scores, - ) diff --git a/src/diffusers/pipelines/stable_diffusion_safe/__init__.py b/src/diffusers/pipelines/stable_diffusion_safe/__init__.py new file mode 100644 index 000000000000..59ff61fa3b54 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_safe/__init__.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional, Union + +import numpy as np + +import PIL +from PIL import Image + +from ...utils import BaseOutput, is_torch_available, is_transformers_available + + +@dataclass +class SafetyConfig(object): + WEAK = { + "sld_warmup_steps": 15, + "sld_guidance_scale": 20, + "sld_threshold": 0.0, + "sld_momentum_scale": 0.0, + "sld_mom_beta": 0.0, + } + MEDIUM = { + "sld_warmup_steps": 10, + "sld_guidance_scale": 1000, + "sld_threshold": 0.01, + "sld_momentum_scale": 0.3, + "sld_mom_beta": 0.4, + } + STRONG = { + "sld_warmup_steps": 7, + "sld_guidance_scale": 2000, + "sld_threshold": 0.025, + "sld_momentum_scale": 0.5, + "sld_mom_beta": 0.7, + } + MAX = { + "sld_warmup_steps": 0, + "sld_guidance_scale": 5000, + "sld_threshold": 1.0, + "sld_momentum_scale": 0.5, + "sld_mom_beta": 0.7, + } + + +@dataclass +class StableDiffusionSafePipelineOutput(BaseOutput): + """ + Output class for Safe Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images that were flagged by the safety checker any may contain "not-safe-for-work" + (nsfw) content, or `None` if no safety check was performed or no images were flagged. + applied_safety_concept (`str`) + The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]] + applied_safety_concept: Optional[str] + + +if is_transformers_available() and is_torch_available(): + from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe + from .safety_checker import SafeStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py new file mode 100644 index 000000000000..7948bbecf8ec --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -0,0 +1,757 @@ +import inspect +import warnings +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import deprecate, is_accelerate_available, logging +from . import StableDiffusionSafePipelineOutput +from .safety_checker import SafeStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionPipelineSafe(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Safe Latent Diffusion. + + The implementation is based on the [`StableDiffusionPipeline`] + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[ + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + ], + safety_checker: SafeStableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + safety_concept: Optional[str] = ( + "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity," + " bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child" + " abuse, brutality, cruelty" + ) + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self._safety_text_concept = safety_concept + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @property + def safety_concept(self): + r""" + Getter method for the safety concept used with SLD + + Returns: + `str`: The text describing the safety concept + """ + return self._safety_text_concept + + @safety_concept.setter + def safety_concept(self, concept): + r""" + Setter method for the safety concept used with SLD + + Args: + concept (`str`): + The text of the new safety concept + """ + self._safety_text_concept = concept + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device("cuda") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + enable_safety_guidance, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # Encode the safety concept text + if enable_safety_guidance: + safety_concept_input = self.tokenizer( + [self._safety_text_concept], + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0] + + # duplicate safety embeddings for each generation per prompt, using mps friendly method + seq_len = safety_embeddings.shape[1] + safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1) + safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance + sld, we need to do three forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing three forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, safety_embeddings]) + + else: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def run_safety_checker(self, image, device, dtype, enable_safety_guidance): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + flagged_images = None + if any(has_nsfw_concept): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead." + f" {'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'} " + ) + flagged_images = np.zeros((2, *image.shape[1:])) + for idx, has_nsfw_concept in enumerate(has_nsfw_concept): + if has_nsfw_concept: + flagged_images[idx] = image[idx] + image[idx] = np.zeros(image[idx].shape) # black image + else: + has_nsfw_concept = None + flagged_images = None + return image, has_nsfw_concept, flagged_images + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def perform_safety_guidance( + self, + enable_safety_guidance, + safety_momentum, + noise_guidance, + noise_pred_out, + i, + sld_guidance_scale, + sld_warmup_steps, + sld_threshold, + sld_momentum_scale, + sld_mom_beta, + ): + # Perform SLD guidance + if enable_safety_guidance: + if safety_momentum is None: + safety_momentum = torch.zeros_like(noise_guidance) + noise_pred_text, noise_pred_uncond = noise_pred_out[0], noise_pred_out[1] + noise_pred_safety_concept = noise_pred_out[2] + + # Equation 6 + scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0) + + # Equation 6 + safety_concept_scale = torch.where( + (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale + ) + + # Equation 4 + noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale) + + # Equation 7 + noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum + + # Equation 8 + safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety + + if i >= sld_warmup_steps: # Warmup + # Equation 3 + noise_guidance = noise_guidance - noise_guidance_safety + return noise_guidance, safety_momentum + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + sld_guidance_scale: Optional[float] = 1000, + sld_warmup_steps: Optional[int] = 10, + sld_threshold: Optional[float] = 0.01, + sld_momentum_scale: Optional[float] = 0.3, + sld_mom_beta: Optional[float] = 0.4, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + sld_guidance_scale (`float`, *optional*, defaults to 1000): + Safe latent guidance as defined in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). + `sld_guidance_scale` is defined as sS of Eq. 6. If set to be less than 1, safety guidance will be + disabled. + sld_warmup_steps (`int`, *optional*, defaults to 10): + Number of warmup steps for safety guidance. SLD will only be applied for diffusion steps greater than + `sld_warmup_steps`. `sld_warmup_steps` is defined as `delta` of [Safe Latent + Diffusion](https://arxiv.org/abs/2211.05105). + sld_threshold (`float`, *optional*, defaults to 0.01): + Threshold that separates the hyperplane between appropriate and inappropriate images. `sld_threshold` + is defined as `lamda` of Eq. 5 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). + sld_momentum_scale (`float`, *optional*, defaults to 0.3): + Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0 + momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller + than `sld_warmup_steps`. `sld_momentum_scale` is defined as `sm` of Eq. 7 in [Safe Latent + Diffusion](https://arxiv.org/abs/2211.05105). + sld_mom_beta (`float`, *optional*, defaults to 0.4): + Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous + momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller + than `sld_warmup_steps`. `sld_mom_beta` is defined as `beta m` of Eq. 8 in [Safe Latent + Diffusion](https://arxiv.org/abs/2211.05105). + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance + if not enable_safety_guidance: + warnings.warn("Safety checker disabled!") + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + safety_momentum = None + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * (3 if enable_safety_guidance else 2)) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2)) + noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] + + # default classifier free guidance + noise_guidance = noise_pred_text - noise_pred_uncond + + # Perform SLD guidance + if enable_safety_guidance: + if safety_momentum is None: + safety_momentum = torch.zeros_like(noise_guidance) + noise_pred_safety_concept = noise_pred_out[2] + + # Equation 6 + scale = torch.clamp( + torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0 + ) + + # Equation 6 + safety_concept_scale = torch.where( + (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale + ) + + # Equation 4 + noise_guidance_safety = torch.mul( + (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale + ) + + # Equation 7 + noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum + + # Equation 8 + safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety + + if i >= sld_warmup_steps: # Warmup + # Equation 3 + noise_guidance = noise_guidance - noise_guidance_safety + + noise_pred = noise_pred_uncond + guidance_scale * noise_guidance + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept, flagged_images = self.run_safety_checker( + image, device, text_embeddings.dtype, enable_safety_guidance + ) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + if flagged_images is not None: + flagged_images = self.numpy_to_pil(flagged_images) + + if not return_dict: + return ( + image, + has_nsfw_concept, + self._safety_text_concept if enable_safety_guidance else None, + flagged_images, + ) + + return StableDiffusionSafePipelineOutput( + images=image, + nsfw_content_detected=has_nsfw_concept, + applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None, + unsafe_images=flagged_images, + ) diff --git a/src/diffusers/pipelines/stable_diffusion_safe/safety_checker.py b/src/diffusers/pipelines/stable_diffusion_safe/safety_checker.py new file mode 100644 index 000000000000..f9dbf51e8644 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_safe/safety_checker.py @@ -0,0 +1,110 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class SafeStableDiffusionSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) + self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + return images, has_nsfw_concepts diff --git a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py index 9e8864b4ca76..739de8ebe620 100644 --- a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py @@ -1,4 +1,17 @@ -#!/usr/bin/env python3 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Optional, Tuple, Union import torch diff --git a/src/diffusers/pipelines/versatile_diffusion/__init__.py b/src/diffusers/pipelines/versatile_diffusion/__init__.py new file mode 100644 index 000000000000..1d2caa7e2399 --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/__init__.py @@ -0,0 +1,16 @@ +from ...utils import is_torch_available, is_transformers_available, is_transformers_version + + +if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"): + from .modeling_text_unet import UNetFlatConditionModel + from .pipeline_versatile_diffusion import VersatileDiffusionPipeline + from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline + from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline + from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline +else: + from ...utils.dummy_torch_and_transformers_objects import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py new file mode 100644 index 000000000000..37a79b5c1be7 --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -0,0 +1,1135 @@ +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...modeling_utils import ModelMixin +from ...models.attention import DualTransformer2DModel, Transformer2DModel +from ...models.embeddings import TimestepEmbedding, Timesteps +from ...models.unet_2d_condition import UNet2DConditionOutput +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlockFlat": + return DownBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + ) + elif down_block_type == "CrossAttnDownBlockFlat": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat") + return CrossAttnDownBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + ) + raise ValueError(f"{down_block_type} is not supported.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlockFlat": + return UpBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + ) + elif up_block_type == "CrossAttnUpBlockFlat": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat") + return CrossAttnUpBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + ) + raise ValueError(f"{up_block_type} is not supported.") + + +# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat +class UNetFlatConditionModel(ModelMixin, ConfigMixin): + r""" + UNetFlatConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a + timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "DownBlockFlat", + ), + up_block_types: Tuple[str] = ( + "UpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + num_class_embeds: Optional[int] = None, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = LinearMultiDim(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlockFlatCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + head_dims = self.config.attention_head_dim + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): + raise ValueError( + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" + ) + if slice_size is not None and slice_size > min(head_dims): + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + ) + + for block in self.down_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_attention_slice(slice_size) + + self.mid_block.set_attention_slice(slice_size) + + for block in self.up_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_attention_slice(slice_size) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for block in self.down_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + for block in self.up_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.config.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + +class LinearMultiDim(nn.Linear): + def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs): + in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features) + if out_features is None: + out_features = in_features + out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features) + self.in_features_multidim = in_features + self.out_features_multidim = out_features + super().__init__(np.array(in_features).prod(), np.array(out_features).prod()) + + def forward(self, input_tensor, *args, **kwargs): + shape = input_tensor.shape + n_dim = len(self.in_features_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features) + output_tensor = super().forward(input_tensor) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim) + return output_tensor + + +class ResnetBlockFlat(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + time_embedding_norm="default", + use_in_shortcut=None, + second_dim=4, + **kwargs, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + + in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels) + self.in_channels_prod = np.array(in_channels).prod() + self.channels_multidim = in_channels + + if out_channels is not None: + out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels) + out_channels_prod = np.array(out_channels).prod() + self.out_channels_multidim = out_channels + else: + out_channels_prod = self.in_channels_prod + self.out_channels_multidim = self.channels_multidim + self.time_embedding_norm = time_embedding_norm + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True) + self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0) + + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0) + + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = ( + self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, input_tensor, temb): + shape = input_tensor.shape + n_dim = len(self.channels_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1) + input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1) + + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + output_tensor = output_tensor.view(*shape[0:-n_dim], -1) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim) + + return output_tensor + + +# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim +class DownBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim +class CrossAttnDownBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def set_attention_slice(self, slice_size): + head_dims = self.attn_num_head_channels + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): + raise ValueError( + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" + ) + if slice_size is not None and slice_size > min(head_dims): + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class UpBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class CrossAttnUpBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def set_attention_slice(self, slice_size): + head_dims = self.attn_num_head_channels + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): + raise ValueError( + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" + ) + if slice_size is not None and slice_size > min(head_dims): + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlatCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + ): + super().__init__() + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def set_attention_slice(self, slice_size): + head_dims = self.attn_num_head_channels + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): + raise ValueError( + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" + ) + if slice_size is not None and slice_size > min(head_dims): + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py new file mode 100644 index 000000000000..7be7f4d3aee6 --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py @@ -0,0 +1,463 @@ +import inspect +from typing import Callable, List, Optional, Union + +import torch + +import PIL.Image +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import logging +from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline +from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline +from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionMegaSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPFeatureExtractor + text_encoder: CLIPTextModel + image_encoder: CLIPVisionModel + image_unet: UNet2DConditionModel + text_unet: UNet2DConditionModel + vae: AutoencoderKL + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + + def __init__( + self, + tokenizer: CLIPTokenizer, + image_feature_extractor: CLIPFeatureExtractor, + text_encoder: CLIPTextModel, + image_encoder: CLIPVisionModel, + image_unet: UNet2DConditionModel, + text_unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.image_unet.config.attention_head_dim // 2 + self.image_unet.set_attention_slice(slice_size) + self.text_unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + @torch.no_grad() + def image_variation( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe.image_variation(image, generator=generator).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + expected_components = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + return VersatileDiffusionImageVariationPipeline(**components)( + image=image, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + + @torch.no_grad() + def text_to_image( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0] + >>> image.save("./astronaut.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + expected_components = inspect.signature(VersatileDiffusionTextToImagePipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + temp_pipeline = VersatileDiffusionTextToImagePipeline(**components) + output = temp_pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + # swap the attention blocks back to the original state + temp_pipeline._swap_unet_attention_blocks() + + return output + + @torch.no_grad() + def dual_guided( + self, + prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], + image: Union[str, List[str]], + text_to_image_strength: float = 0.5, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + >>> text = "a red car in the sun" + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> text_to_image_strength = 0.75 + + >>> image = pipe.dual_guided( + ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator + ... ).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images. + """ + + expected_components = inspect.signature(VersatileDiffusionDualGuidedPipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components) + output = temp_pipeline( + prompt=prompt, + image=image, + text_to_image_strength=text_to_image_strength, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + temp_pipeline._revert_dual_attention() + + return output diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py new file mode 100644 index 000000000000..fa1754a4f062 --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -0,0 +1,641 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint + +import PIL +from transformers import ( + CLIPFeatureExtractor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention import DualTransformer2DModel, Transformer2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import is_accelerate_available, logging +from .modeling_text_unet import UNetFlatConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPFeatureExtractor + text_encoder: CLIPTextModelWithProjection + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + + _optional_components = ["text_unet"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + image_feature_extractor: CLIPFeatureExtractor, + text_encoder: CLIPTextModelWithProjection, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + self.register_modules( + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + if self.text_unet is not None and ( + "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention + ): + # if loading from a universal checkpoint rather than a saved dual-guided pipeline + self._convert_to_dual_attention() + + def remove_unused_weights(self): + self.register_modules(text_unet=None) + + def _convert_to_dual_attention(self): + """ + Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` that contains transformer blocks + from both `image_unet` and `text_unet` + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + + image_transformer = self.image_unet.get_submodule(parent_name)[index] + text_transformer = self.text_unet.get_submodule(parent_name)[index] + + config = image_transformer.config + dual_transformer = DualTransformer2DModel( + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + in_channels=config.in_channels, + num_layers=config.num_layers, + dropout=config.dropout, + norm_num_groups=config.norm_num_groups, + cross_attention_dim=config.cross_attention_dim, + attention_bias=config.attention_bias, + sample_size=config.sample_size, + num_vector_embeds=config.num_vector_embeds, + activation_fn=config.activation_fn, + num_embeds_ada_norm=config.num_embeds_ada_norm, + ) + dual_transformer.transformers[0] = image_transformer + dual_transformer.transformers[1] = text_transformer + + self.image_unet.get_submodule(parent_name)[index] = dual_transformer + self.image_unet.register_to_config(dual_cross_attention=True) + + def _revert_dual_attention(self): + """ + Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights Call + this function if you reuse `image_unet` in another pipeline, e.g. `VersatileDiffusionPipeline` + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, DualTransformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index] = module.transformers[0] + + self.image_unet.register_to_config(dual_cross_attention=False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.image_unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.image_unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.image_unet.config.attention_head_dim) + + self.image_unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"): + return self.device + for module in self.image_unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + """ + + def normalize_embeddings(encoder_output): + embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) + embeds_pooled = encoder_output.text_embeds + embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = normalize_embeddings(text_embeddings) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = normalize_embeddings(uncond_embeddings) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) + image_embeddings = self.image_encoder(pixel_values) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) + uncond_embeddings = self.image_encoder(pixel_values) + uncond_embeddings = normalize_embeddings(uncond_embeddings) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, image, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` `PIL.Image` or `list` but is {type(prompt)}") + if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list): + raise ValueError(f"`image` has to be of type `str` `PIL.Image` or `list` but is {type(image)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: Tuple = ("text", "image")): + for name, module in self.image_unet.named_modules(): + if isinstance(module, DualTransformer2DModel): + module.mix_ratio = mix_ratio + + for i, type in enumerate(condition_types): + if type == "text": + module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings + module.transformer_index_for_condition[i] = 1 # use the second (text) transformer + else: + module.condition_lengths[i] = 257 + module.transformer_index_for_condition[i] = 0 # use the first (image) transformer + + @torch.no_grad() + def __call__( + self, + prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], + image: Union[str, List[str]], + text_to_image_strength: float = 0.5, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionDualGuidedPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + >>> text = "a red car in the sun" + + >>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe.remove_unused_weights() + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> text_to_image_strength = 0.75 + + >>> image = pipe( + ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator + ... ).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, image, height, width, callback_steps) + + # 2. Define call parameters + prompt = [prompt] if not isinstance(prompt, list) else prompt + image = [image] if not isinstance(image, list) else image + batch_size = len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompts + text_embeddings = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) + image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance) + dual_prompt_embeddings = torch.cat([text_embeddings, image_embeddings], dim=1) + prompt_types = ("text", "image") + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + dual_prompt_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Combine the attention blocks of the image and text UNets + self.set_transformer_params(text_to_image_strength, prompt_types) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py new file mode 100644 index 000000000000..3e51ce6371f0 --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -0,0 +1,471 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +import torch.utils.checkpoint + +import PIL +from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import is_accelerate_available, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + image_feature_extractor: CLIPFeatureExtractor + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + vae: AutoencoderKL + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + + def __init__( + self, + image_feature_extractor: CLIPFeatureExtractor, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + self.register_modules( + image_feature_extractor=image_feature_extractor, + image_encoder=image_encoder, + image_unet=image_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.image_unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.image_unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.image_unet.config.attention_head_dim) + + self.image_unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"): + return self.device + for module in self.image_unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) + image_embeddings = self.image_encoder(pixel_values) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images: List[str] + if negative_prompt is None: + uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, PIL.Image.Image): + uncond_images = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_images = negative_prompt + + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) + uncond_embeddings = self.image_encoder(pixel_values) + uncond_embeddings = normalize_embeddings(uncond_embeddings) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, image, height, width, callback_steps): + if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor): + raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionImageVariationPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + + >>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe(image, generator=generator).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + image_embeddings = self._encode_prompt( + image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py new file mode 100644 index 000000000000..e77f5a2f22e4 --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -0,0 +1,525 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import torch +import torch.utils.checkpoint + +from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention import Transformer2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import is_accelerate_available, logging +from .modeling_text_unet import UNetFlatConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPFeatureExtractor + text_encoder: CLIPTextModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + + _optional_components = ["text_unet"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + if self.text_unet is not None: + self._swap_unet_attention_blocks() + + def _swap_unet_attention_blocks(self): + """ + Swap the `Transformer2DModel` blocks between the image and text UNets + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = ( + self.text_unet.get_submodule(parent_name)[index], + self.image_unet.get_submodule(parent_name)[index], + ) + + def remove_unused_weights(self): + self.register_modules(text_unet=None) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.image_unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.image_unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.image_unet.config.attention_head_dim) + + self.image_unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"): + return self.device + for module in self.image_unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) + embeds_pooled = encoder_output.text_embeds + embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = normalize_embeddings(text_embeddings) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = normalize_embeddings(uncond_embeddings) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionTextToImagePipeline + >>> import torch + + >>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe.remove_unused_weights() + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0] + >>> image.save("./astronaut.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/vq_diffusion/__init__.py b/src/diffusers/pipelines/vq_diffusion/__init__.py new file mode 100644 index 000000000000..8c9f14f00064 --- /dev/null +++ b/src/diffusers/pipelines/vq_diffusion/__init__.py @@ -0,0 +1,5 @@ +from ...utils import is_torch_available, is_transformers_available + + +if is_transformers_available() and is_torch_available(): + from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py new file mode 100644 index 000000000000..333599d7ecf8 --- /dev/null +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -0,0 +1,335 @@ +# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Tuple, Union + +import torch + +from diffusers import Transformer2DModel, VQModel +from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler +from transformers import CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import ConfigMixin, register_to_config +from ...modeling_utils import ModelMixin +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin): + """ + Utility class for storing learned text embeddings for classifier free sampling + """ + + @register_to_config + def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None): + super().__init__() + + self.learnable = learnable + + if self.learnable: + assert hidden_size is not None, "learnable=True requires `hidden_size` to be set" + assert length is not None, "learnable=True requires `length` to be set" + + embeddings = torch.zeros(length, hidden_size) + else: + embeddings = None + + self.embeddings = torch.nn.Parameter(embeddings) + + +class VQDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using VQ Diffusion + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vqvae ([`VQModel`]): + Vector Quantized Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent + representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. VQ Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + transformer ([`Transformer2DModel`]): + Conditional transformer to denoise the encoded image latents. + scheduler ([`VQDiffusionScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + vqvae: VQModel + text_encoder: CLIPTextModel + tokenizer: CLIPTokenizer + transformer: Transformer2DModel + learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings + scheduler: VQDiffusionScheduler + + def __init__( + self, + vqvae: VQModel, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + transformer: Transformer2DModel, + scheduler: VQDiffusionScheduler, + learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, + ) + + def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + + # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. + # While CLIP does normalize the pooled output of the text transformer when combining + # the image and text embeddings, CLIP does not directly normalize the last hidden state. + # + # CLIP normalizing the pooled output. + # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 + text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True) + + # duplicate text embeddings for each generation per prompt + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + if self.learned_classifier_free_sampling_embeddings.learnable: + uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings + uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1) + else: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + # See comment for normalizing text embeddings + uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + num_inference_steps: int = 100, + guidance_scale: float = 5.0, + truncation_rate: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)): + Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at + most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above + `truncation_rate` are set to zero. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor` of shape (batch), *optional*): + Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices. + Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will + be generated of completely masked latent pixels. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + + text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # get the initial completely masked latents unless the user supplied it + + latents_shape = (batch_size, self.transformer.num_latent_pixels) + if latents is None: + mask_class = self.transformer.num_vector_embeds - 1 + latents = torch.full(latents_shape, mask_class).to(self.device) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any(): + raise ValueError( + "Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0," + f" {self.transformer.num_vector_embeds - 1} (inclusive)." + ) + latents = latents.to(self.device) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + sample = latents + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the sample if we are doing classifier free guidance + latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample + + # predict the un-noised image + # model_output == `log_p_x_0` + model_output = self.transformer( + latent_model_input, encoder_hidden_states=text_embeddings, timestep=t + ).sample + + if do_classifier_free_guidance: + model_output_uncond, model_output_text = model_output.chunk(2) + model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond) + model_output -= torch.logsumexp(model_output, dim=1, keepdim=True) + + model_output = self.truncate(model_output, truncation_rate) + + # remove `log(0)`'s (`-inf`s) + model_output = model_output.clamp(-70) + + # compute the previous noisy sample x_t -> x_t-1 + sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, sample) + + embedding_channels = self.vqvae.config.vq_embed_dim + embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels) + embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape) + image = self.vqvae.decode(embeddings, force_not_quantize=True).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + def truncate(self, log_p_x_0: torch.FloatTensor, truncation_rate: float) -> torch.FloatTensor: + """ + Truncates log_p_x_0 such that for each column vector, the total cumulative probability is `truncation_rate` The + lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to zero. + """ + sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True) + sorted_p_x_0 = torch.exp(sorted_log_p_x_0) + keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate + + # Ensure that at least the largest probability is not zeroed out + all_true = torch.full_like(keep_mask[:, 0:1, :], True) + keep_mask = torch.cat((all_true, keep_mask), dim=1) + keep_mask = keep_mask[:, :-1, :] + + keep_mask = keep_mask.gather(1, indices.argsort(1)) + + rv = log_p_x_0.clone() + + rv[~keep_mask] = -torch.inf # -inf = log(0) + + return rv diff --git a/src/diffusers/schedulers/README.md b/src/diffusers/schedulers/README.md index 6a01c503a909..9494e357fd43 100644 --- a/src/diffusers/schedulers/README.md +++ b/src/diffusers/schedulers/README.md @@ -1,17 +1,3 @@ # Schedulers -- Schedulers are the algorithms to use diffusion models in inference as well as for training. They include the noise schedules and define algorithm-specific diffusion steps. -- Schedulers can be used interchangeable between diffusion models in inference to find the preferred trade-off between speed and generation quality. -- Schedulers are available in PyTorch and Jax. - -## API - -- Schedulers should provide one or more `def step(...)` functions that should be called iteratively to unroll the diffusion loop during -the forward pass. -- Schedulers should be framework specific. - -## Examples - -- The DDPM scheduler was proposed in [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and can be found in [scheduling_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py). An example of how to use this scheduler can be found in [pipeline_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm/pipeline_ddpm.py). -- The DDIM scheduler was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) and can be found in [scheduling_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py). An example of how to use this scheduler can be found in [pipeline_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim/pipeline_ddim.py). -- The PNDM scheduler was proposed in [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) and can be found in [scheduling_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py). An example of how to use this scheduler can be found in [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm/pipeline_pndm.py). +For more information on the schedulers, please refer to the [docs](https://huggingface.co/docs/diffusers/api/schedulers). \ No newline at end of file diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index a906c39eb24c..6217bfcd6985 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -19,22 +19,29 @@ if is_torch_available(): from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler + from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler + from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler + from .scheduling_euler_discrete import EulerDiscreteScheduler + from .scheduling_ipndm import IPNDMScheduler from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_pndm import PNDMScheduler + from .scheduling_repaint import RePaintScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_utils import SchedulerMixin + from .scheduling_vq_diffusion import VQDiffusionScheduler else: from ..utils.dummy_pt_objects import * # noqa F403 if is_flax_available(): from .scheduling_ddim_flax import FlaxDDIMScheduler from .scheduling_ddpm_flax import FlaxDDPMScheduler + from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler - from .scheduling_utils_flax import FlaxSchedulerMixin + from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left else: from ..utils.dummy_flax_objects import * # noqa F403 diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 0f1a40229475..3640b37546ec 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -23,11 +23,12 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, deprecate +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate from .scheduling_utils import SchedulerMixin @dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM class DDIMSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. @@ -81,8 +82,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2010.02502 @@ -105,9 +106,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. + `v-prediction` is not supported for this scheduler. """ + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] + @register_to_config def __init__( self, @@ -119,7 +126,17 @@ def __init__( clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, + prediction_type: str = "epsilon", + **kwargs, ): + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " DDIMScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None: + self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") + if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": @@ -149,7 +166,7 @@ def __init__( # setable values self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ @@ -175,7 +192,7 @@ def _get_variance(self, timestep, prev_timestep): return variance - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -183,18 +200,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ - deprecated_offset = deprecate( - "offset", "0.7.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs - ) - offset = deprecated_offset or self.config.steps_offset - self.num_inference_steps = num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy() + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) - self.timesteps += offset + self.timesteps += self.config.steps_offset def step( self, @@ -204,6 +216,7 @@ def step( eta: float = 0.0, use_clipped_model_output: bool = False, generator=None, + variance_noise: Optional[torch.FloatTensor] = None, return_dict: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: """ @@ -216,8 +229,14 @@ def step( sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. eta (`float`): weight of noise for added noise in diffusion step. - use_clipped_model_output (`bool`): TODO + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. generator: random number generator. + variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we + can directly provide the noise for the variance itself. This is useful for methods such as + CycleDiffusion. (https://arxiv.org/abs/2210.05559) return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class Returns: @@ -253,7 +272,19 @@ def step( # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) # 4. Clip "predicted x_0" if self.config.clip_sample: @@ -276,9 +307,23 @@ def step( if eta > 0: # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 - device = model_output.device if torch.is_tensor(model_output) else "cpu" - noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) - variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise + device = model_output.device + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + if device.type == "mps": + # randn does not work reproducibly on mps + variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator) + variance_noise = variance_noise.to(device) + else: + variance_noise = torch.randn( + model_output.shape, generator=generator, device=device, dtype=model_output.dtype + ) + variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise prev_sample = prev_sample + variance diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 7726ef2eda0d..f98d9770043f 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -23,7 +23,13 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput +from ..utils import deprecate +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: @@ -79,8 +85,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2010.02502 @@ -103,8 +109,15 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. + `v-prediction` is not supported for this scheduler. + """ + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] + @property def has_state(self): return True @@ -118,7 +131,17 @@ def __init__( beta_schedule: str = "linear", set_alpha_to_one: bool = True, steps_offset: int = 0, + prediction_type: str = "epsilon", + **kwargs, ): + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " FlaxDDIMScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None: + self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") + if beta_schedule == "linear": self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) elif beta_schedule == "scaled_linear": @@ -173,7 +196,9 @@ def _get_variance(self, timestep, prev_timestep, alphas_cumprod): return variance - def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState: + def set_timesteps( + self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = () + ) -> DDIMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -211,9 +236,6 @@ def step( timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. - key (`random.KeyArray`): a PRNG key. - eta (`float`): weight of noise for added noise in diffusion step. - use_clipped_model_output (`bool`): TODO return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class Returns: @@ -253,7 +275,19 @@ def step( # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) # 4. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) @@ -279,13 +313,11 @@ def add_noise( ) -> jnp.ndarray: sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod[:, None] + sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None] + sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index d51d58ac8f45..6f131659c2ba 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -21,8 +21,8 @@ import numpy as np import torch -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate from .scheduling_utils import SchedulerMixin @@ -80,8 +80,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2006.11239 @@ -99,9 +99,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. - + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. + `v-prediction` is not supported for this scheduler. """ + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] + @register_to_config def __init__( self, @@ -112,7 +117,17 @@ def __init__( trained_betas: Optional[np.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, + prediction_type: str = "epsilon", + **kwargs, ): + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " DDPMScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None: + self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") + if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": @@ -192,6 +207,7 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): # for rl-diffuser https://arxiv.org/abs/2205.09991 elif variance_type == "fixed_small_log": variance = torch.log(torch.clamp(variance, min=1e-20)) + variance = torch.exp(0.5 * variance) elif variance_type == "fixed_large": variance = self.betas[t] elif variance_type == "fixed_large_log": @@ -212,9 +228,9 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, - predict_epsilon=True, generator=None, return_dict: bool = True, + **kwargs, ) -> Union[DDPMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion @@ -225,8 +241,6 @@ def step( timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. - predict_epsilon (`bool`): - optional flag to use when model predicts the samples directly instead of the noise, epsilon. generator: random number generator. return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class @@ -236,6 +250,16 @@ def step( returning a tuple, the first element is the sample tensor. """ + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " DDPMScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None: + new_config = dict(self.config) + new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample" + self._internal_dict = FrozenDict(new_config) + t = timestep if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: @@ -251,10 +275,15 @@ def step( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if predict_epsilon: + if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - else: + elif self.config.prediction_type == "sample": pred_original_sample = model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` " + " for the DDPMScheduler." + ) # 3. Clip "predicted x_0" if self.config.clip_sample: @@ -272,10 +301,19 @@ def step( # 6. Add noise variance = 0 if t > 0: - noise = torch.randn( - model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator - ).to(model_output.device) - variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise + device = model_output.device + if device.type == "mps": + # randn does not work reproducibly on mps + variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator) + variance_noise = variance_noise.to(device) + else: + variance_noise = torch.randn( + model_output.shape, generator=generator, device=device, dtype=model_output.dtype + ) + if self.variance_type == "fixed_small_log": + variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise + else: + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise pred_prev_sample = pred_prev_sample + variance diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index 7b3265611101..97b38fd3a17e 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -22,8 +22,14 @@ import jax.numpy as jnp from jax import random -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput +from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config +from ..utils import deprecate +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: @@ -78,8 +84,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2006.11239 @@ -97,10 +103,18 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. - + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. + `v-prediction` is not supported for this scheduler. """ + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] + + @property + def has_state(self): + return True + @register_to_config def __init__( self, @@ -111,7 +125,17 @@ def __init__( trained_betas: Optional[jnp.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, + prediction_type: str = "epsilon", + **kwargs, ): + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " FlaxDDPMScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None: + self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") + if trained_betas is not None: self.betas = jnp.asarray(trained_betas) elif beta_schedule == "linear": @@ -129,11 +153,12 @@ def __init__( self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) self.one = jnp.array(1.0) - self.state = DDPMSchedulerState.create(num_train_timesteps=num_train_timesteps) + def create_state(self): + return DDPMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) - self.variance_type = variance_type - - def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState: + def set_timesteps( + self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = () + ) -> DDPMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -189,8 +214,8 @@ def step( timestep: int, sample: jnp.ndarray, key: random.KeyArray, - predict_epsilon: bool = True, return_dict: bool = True, + **kwargs, ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion @@ -203,8 +228,6 @@ def step( sample (`jnp.ndarray`): current instance of sample being created by diffusion process. key (`random.KeyArray`): a PRNG key. - predict_epsilon (`bool`): - optional flag to use when model predicts the samples directly instead of the noise, epsilon. return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class Returns: @@ -212,9 +235,19 @@ def step( `tuple`. When returning a tuple, the first element is the sample tensor. """ + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " FlaxDDPMScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None: + new_config = dict(self.config) + new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample" + self._internal_dict = FrozenDict(new_config) + t = timestep - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1) else: predicted_variance = None @@ -227,10 +260,15 @@ def step( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if predict_epsilon: + if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - else: + elif self.config.prediction_type == "sample": pred_original_sample = model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` " + " for the FlaxDDPMScheduler." + ) # 3. Clip "predicted x_0" if self.config.clip_sample: @@ -267,13 +305,11 @@ def add_noise( ) -> jnp.ndarray: sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod[..., None] + sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None] + sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py new file mode 100644 index 000000000000..76dc7acc1b9f --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -0,0 +1,527 @@ +# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with + the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality + samples, and it can generate quite good samples even in only 10 steps. + + For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 + + Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We + recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic + thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + solver_order (`int`, default `2`): + the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, + or `v-prediction`. + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to + use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion + models (such as stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++`. + algorithm_type (`str`, default `dpmsolver++`): + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the + algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in + https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided + sampling (e.g. stable-diffusion). + solver_type (`str`, default `midpoint`): + the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects + the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are + slightly better, so we recommend to use the `midpoint` type. + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. + + """ + + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + **kwargs, + ): + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " DPMSolverMultistepScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None: + self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") + + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++"]: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + if solver_type not in ["midpoint", "heun"]: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + timesteps = ( + np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + """ + Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. + + DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to + discretize an integral of the data prediction model. So we need to first convert the model output to the + corresponding type to match the algorithm. + + Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or + DPM-Solver++ for both noise prediction model and data prediction model. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the converted model output. + """ + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type == "dpmsolver++": + if self.config.prediction_type == "epsilon": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + # Dynamic thresholding in https://arxiv.org/abs/2205.11487 + dynamic_max_val = torch.quantile( + torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1 + ) + dynamic_max_val = torch.maximum( + dynamic_max_val, + self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device), + )[(...,) + (None,) * (x0_pred.ndim - 1)] + x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val + return x0_pred + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type == "dpmsolver": + if self.config.prediction_type == "epsilon": + return model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + def dpm_solver_first_order_update( + self, + model_output: torch.FloatTensor, + timestep: int, + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the first-order DPM-Solver (equivalent to DDIM). + + See https://arxiv.org/abs/2206.00927 for the detailed derivation. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] + alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] + sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the second-order multistep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + m0, m1 = model_output_list[-1], model_output_list[-2] + lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the third-order multistep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( + self.lambda_t[t], + self.lambda_t[s0], + self.lambda_t[s1], + self.lambda_t[s2], + ) + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the multistep DPM-Solver. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + lower_order_final = ( + (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + lower_order_second = ( + (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, timestep, sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + timestep_list = [self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, timestep_list, prev_timestep, sample + ) + else: + timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, timestep_list, prev_timestep, sample + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py new file mode 100644 index 000000000000..78b611ae2721 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -0,0 +1,625 @@ +# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import flax +import jax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) + + +def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return jnp.array(betas, dtype=jnp.float32) + + +@flax.struct.dataclass +class DPMSolverMultistepSchedulerState: + # setable values + num_inference_steps: Optional[int] = None + timesteps: Optional[jnp.ndarray] = None + + # running values + model_outputs: Optional[jnp.ndarray] = None + lower_order_nums: Optional[int] = None + step_index: Optional[int] = None + prev_timestep: Optional[int] = None + cur_sample: Optional[jnp.ndarray] = None + + @classmethod + def create(cls, num_train_timesteps: int): + return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1]) + + +@dataclass +class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput): + state: DPMSolverMultistepSchedulerState + + +class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with + the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality + samples, and it can generate quite good samples even in only 10 steps. + + For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 + + Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We + recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic + thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + solver_order (`int`, default `2`): + the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, + or `v-prediction`. + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to + use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion + models (such as stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++`. + algorithm_type (`str`, default `dpmsolver++`): + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the + algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in + https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided + sampling (e.g. stable-diffusion). + solver_type (`str`, default `midpoint`): + the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects + the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are + slightly better, so we recommend to use the `midpoint` type. + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. + + """ + + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + **kwargs, + ): + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " FlaxDPMSolverMultistepScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None: + self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") + + if trained_betas is not None: + self.betas = jnp.asarray(trained_betas) + elif beta_schedule == "linear": + self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + # Currently we only support VP-type noise schedule + self.alpha_t = jnp.sqrt(self.alphas_cumprod) + self.sigma_t = jnp.sqrt(1 - self.alphas_cumprod) + self.lambda_t = jnp.log(self.alpha_t) - jnp.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++"]: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + if solver_type not in ["midpoint", "heun"]: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + def create_state(self): + return DPMSolverMultistepSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) + + def set_timesteps( + self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple + ) -> DPMSolverMultistepSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`DPMSolverMultistepSchedulerState`): + the `FlaxDPMSolverMultistepScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + shape (`Tuple`): + the shape of the samples to be generated. + """ + timesteps = ( + jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .astype(jnp.int32) + ) + + return state.replace( + num_inference_steps=num_inference_steps, + timesteps=timesteps, + model_outputs=jnp.zeros((self.config.solver_order,) + shape), + lower_order_nums=0, + step_index=0, + prev_timestep=-1, + cur_sample=jnp.zeros(shape), + ) + + def convert_model_output( + self, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. + + DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to + discretize an integral of the data prediction model. So we need to first convert the model output to the + corresponding type to match the algorithm. + + Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or + DPM-Solver++ for both noise prediction model and data prediction model. + + Args: + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the converted model output. + """ + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type == "dpmsolver++": + if self.config.prediction_type == "epsilon": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + " or `v_prediction` for the FlaxDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + # Dynamic thresholding in https://arxiv.org/abs/2205.11487 + dynamic_max_val = jnp.percentile( + jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim)) + ) + dynamic_max_val = jnp.maximum( + dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val) + ) + x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val + return x0_pred + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type == "dpmsolver": + if self.config.prediction_type == "epsilon": + return model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + " or `v_prediction` for the FlaxDPMSolverMultistepScheduler." + ) + + def dpm_solver_first_order_update( + self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray + ) -> jnp.ndarray: + """ + One step for the first-order DPM-Solver (equivalent to DDIM). + + See https://arxiv.org/abs/2206.00927 for the detailed derivation. + + Args: + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the sample tensor at the previous timestep. + """ + t, s0 = prev_timestep, timestep + m0 = model_output + lambda_t, lambda_s = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s = self.sigma_t[t], self.sigma_t[s0] + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0 + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0 + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: jnp.ndarray, + timestep_list: List[int], + prev_timestep: int, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + One step for the second-order multistep DPM-Solver. + + Args: + model_output_list (`List[jnp.ndarray]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the sample tensor at the previous timestep. + """ + t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + m0, m1 = model_output_list[-1], model_output_list[-2] + lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (jnp.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (jnp.exp(h) - 1.0)) * D0 + - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1 + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + model_output_list: jnp.ndarray, + timestep_list: List[int], + prev_timestep: int, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + One step for the third-order multistep DPM-Solver. + + Args: + model_output_list (`List[jnp.ndarray]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the sample tensor at the previous timestep. + """ + t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( + self.lambda_t[t], + self.lambda_t[s0], + self.lambda_t[s1], + self.lambda_t[s2], + ) + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (jnp.exp(h) - 1.0)) * D0 + - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((jnp.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def step( + self, + state: DPMSolverMultistepSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process + from the learned model outputs (most often the predicted noise). + + Args: + state (`DPMSolverMultistepSchedulerState`): + the `FlaxDPMSolverMultistepScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverMultistepSchedulerOutput class + + Returns: + [`FlaxDPMSolverMultistepSchedulerOutput`] or `tuple`: [`FlaxDPMSolverMultistepSchedulerOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + prev_timestep = jax.lax.cond( + state.step_index == len(state.timesteps) - 1, + lambda _: 0, + lambda _: state.timesteps[state.step_index + 1], + (), + ) + + model_output = self.convert_model_output(model_output, timestep, sample) + + model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0) + model_outputs_new = model_outputs_new.at[-1].set(model_output) + state = state.replace( + model_outputs=model_outputs_new, + prev_timestep=prev_timestep, + cur_sample=sample, + ) + + def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: + return self.dpm_solver_first_order_update( + state.model_outputs[-1], + state.timesteps[state.step_index], + state.prev_timestep, + state.cur_sample, + ) + + def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: + def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: + timestep_list = jnp.array([state.timesteps[state.step_index - 1], state.timesteps[state.step_index]]) + return self.multistep_dpm_solver_second_order_update( + state.model_outputs, + timestep_list, + state.prev_timestep, + state.cur_sample, + ) + + def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: + timestep_list = jnp.array( + [ + state.timesteps[state.step_index - 2], + state.timesteps[state.step_index - 1], + state.timesteps[state.step_index], + ] + ) + return self.multistep_dpm_solver_third_order_update( + state.model_outputs, + timestep_list, + state.prev_timestep, + state.cur_sample, + ) + + if self.config.solver_order == 2: + return step_2(state) + elif self.config.lower_order_final and len(state.timesteps) < 15: + return jax.lax.cond( + state.lower_order_nums < 2, + step_2, + lambda state: jax.lax.cond( + state.step_index == len(state.timesteps) - 2, + step_2, + step_3, + state, + ), + state, + ) + else: + return jax.lax.cond( + state.lower_order_nums < 2, + step_2, + step_3, + state, + ) + + if self.config.solver_order == 1: + prev_sample = step_1(state) + elif self.config.lower_order_final and len(state.timesteps) < 15: + prev_sample = jax.lax.cond( + state.lower_order_nums < 1, + step_1, + lambda state: jax.lax.cond( + state.step_index == len(state.timesteps) - 1, + step_1, + step_23, + state, + ), + state, + ) + else: + prev_sample = jax.lax.cond( + state.lower_order_nums < 1, + step_1, + step_23, + state, + ) + + state = state.replace( + lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order), + step_index=(state.step_index + 1), + ) + + if not return_dict: + return (prev_sample, state) + + return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state) + + def scale_model_input( + self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + state (`DPMSolverMultistepSchedulerState`): + the `FlaxDPMSolverMultistepScheduler` state data class instance. + sample (`jnp.ndarray`): input sample + timestep (`int`, optional): current timestep + + Returns: + `jnp.ndarray`: scaled input sample + """ + return sample + + def add_noise( + self, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py new file mode 100644 index 000000000000..8117f305606a --- /dev/null +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -0,0 +1,264 @@ +# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete +class EulerAncestralDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + + """ + + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + ): + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.is_scale_input_called = False + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + + Returns: + `torch.FloatTensor`: scaled input sample + """ + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + self.timesteps = torch.from_numpy(timesteps).to(device=device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`float`): current timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + generator (`torch.Generator`, optional): Random number generator. + return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise + a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep.", + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma * model_output + sigma_from = self.sigmas[step_index] + sigma_to = self.sigmas[step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = sigma_down - sigma + + prev_sample = sample + derivative * dt + + device = model_output.device + if device.type == "mps": + # randn does not work reproducibly on mps + noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to( + device + ) + else: + noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to( + device + ) + + prev_sample = prev_sample + noise * sigma_up + + if not return_dict: + return (prev_sample,) + + return EulerAncestralDiscreteSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + self.timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + schedule_timesteps = self.timesteps + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = self.sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py new file mode 100644 index 000000000000..4b7b2909e7d4 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -0,0 +1,282 @@ +# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete +class EulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original + k-diffusion implementation by Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + + """ + + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + prediction_type: str = "epsilon", + ): + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.is_scale_input_called = False + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + + Returns: + `torch.FloatTensor`: scaled input sample + """ + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + self.timesteps = torch.from_numpy(timesteps).to(device=device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`float`): current timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + s_churn (`float`) + s_tmin (`float`) + s_tmax (`float`) + s_noise (`float`) + generator (`torch.Generator`, optional): Random number generator. + return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep.", + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + device = model_output.device + if device.type == "mps": + # randn does not work reproducibly on mps + noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to( + device + ) + else: + noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to( + device + ) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma_hat * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma_hat + + dt = self.sigmas[step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + + if not return_dict: + return (prev_sample,) + + return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + self.timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + schedule_timesteps = self.timesteps + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = self.sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ipndm.py b/src/diffusers/schedulers/scheduling_ipndm.py new file mode 100644 index 000000000000..e5495713a834 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_ipndm.py @@ -0,0 +1,152 @@ +# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +class IPNDMScheduler(SchedulerMixin, ConfigMixin): + """ + Improved Pseudo numerical methods for diffusion models (iPNDM) ported from @crowsonkb's amazing k-diffusion + [library](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296) + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2202.09778 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + """ + + @register_to_config + def __init__(self, num_train_timesteps: int = 1000): + # set `betas`, `alphas`, `timesteps` + self.set_timesteps(num_train_timesteps) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + # running values + self.ets = [] + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + self.num_inference_steps = num_inference_steps + steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1] + steps = torch.cat([steps, torch.tensor([0.0])]) + + self.betas = torch.sin(steps * math.pi / 2) ** 2 + self.alphas = (1.0 - self.betas**2) ** 0.5 + + timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1] + self.timesteps = timesteps.to(device) + + self.ets = [] + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + timestep_index = (self.timesteps == timestep).nonzero().item() + prev_timestep_index = timestep_index + 1 + + ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index] + self.ets.append(ets) + + if len(self.ets) == 1: + ets = self.ets[-1] + elif len(self.ets) == 2: + ets = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 + else: + ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + + prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets) + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets): + alpha = self.alphas[timestep_index] + sigma = self.betas[timestep_index] + + next_alpha = self.alphas[prev_timestep_index] + next_sigma = self.betas[prev_timestep_index] + + pred = (sample - sigma * ets) / max(alpha, 1e-8) + prev_sample = next_alpha * pred + ets * next_sigma + + return prev_sample + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 743f2e061c53..b2eb332aed0f 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -56,8 +56,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index caf27aa4c226..c4e612c3cc84 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -67,8 +67,8 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the @@ -87,6 +87,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): A reasonable range is [0.2, 80]. """ + @property + def has_state(self): + return True + @register_to_config def __init__( self, @@ -97,10 +101,13 @@ def __init__( s_min: float = 0.05, s_max: float = 50, ): - self.state = KarrasVeSchedulerState.create() + pass + + def create_state(self): + return KarrasVeSchedulerState.create() def set_timesteps( - self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple + self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = () ) -> KarrasVeSchedulerState: """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 1b8ca7c5df8d..cc9e8d72566a 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -21,11 +21,12 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, deprecate +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput from .scheduling_utils import SchedulerMixin @dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete class LMSDiscreteSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. @@ -51,8 +52,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. @@ -66,6 +67,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): """ + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + @register_to_config def __init__( self, @@ -163,8 +166,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) - self.timesteps = torch.from_numpy(timesteps).to(device=device) + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + self.timesteps = torch.from_numpy(timesteps).to(device=device) self.derivatives = [] @@ -202,22 +210,7 @@ def step( if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) - if ( - isinstance(timestep, int) - or isinstance(timestep, torch.IntTensor) - or isinstance(timestep, torch.LongTensor) - ): - deprecate( - "timestep as an index", - "0.7.0", - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `LMSDiscreteScheduler.step()` will not be supported in future versions. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep.", - standard_warn=False, - ) - step_index = timestep - else: - step_index = (self.timesteps == timestep).nonzero().item() + step_index = (self.timesteps == timestep).nonzero().item() sigma = self.sigmas[step_index] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise @@ -250,26 +243,18 @@ def add_noise( timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - self.timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - schedule_timesteps = self.timesteps - - if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor): - deprecate( - "timesteps as indices", - "0.7.0", - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `LMSDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to" - " pass values from `scheduler.timesteps` as timesteps.", - standard_warn=False, - ) - step_indices = timesteps + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sigma = self.sigmas[step_indices].flatten() + sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index cd71e1960e8c..21f25f72facf 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -20,7 +20,12 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) @flax.struct.dataclass @@ -49,8 +54,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. @@ -63,6 +68,12 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. """ + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + + @property + def has_state(self): + return True + @register_to_config def __init__( self, @@ -85,8 +96,10 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + def create_state(self): self.state = LMSDiscreteSchedulerState.create( - num_train_timesteps=num_train_timesteps, sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + num_train_timesteps=self.config.num_train_timesteps, + sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5, ) def get_lms_coefficient(self, state, order, t, current_order): @@ -112,7 +125,7 @@ def lms_derivative(tau): return integrated_coeff def set_timesteps( - self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple + self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () ) -> LMSDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -199,8 +212,7 @@ def add_noise( timesteps: jnp.ndarray, ) -> jnp.ndarray: sigma = state.sigmas[timesteps].flatten() - while len(sigma.shape) < len(noise.shape): - sigma = sigma[..., None] + sigma = broadcast_to_shape_from_left(sigma, noise.shape) noisy_samples = original_samples + noise * sigma diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 99ccc6c66f20..8bf0a5958266 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -21,7 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import deprecate +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS from .scheduling_utils import SchedulerMixin, SchedulerOutput @@ -61,8 +61,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2202.09778 @@ -89,6 +89,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): """ + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + @register_to_config def __init__( self, @@ -142,7 +144,7 @@ def __init__( self.plms_timesteps = None self.timesteps = None - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -150,17 +152,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ - deprecated_offset = deprecate( - "offset", "0.7.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs - ) - offset = deprecated_offset or self.config.steps_offset self.num_inference_steps = num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() - self._timesteps += offset + self._timesteps += self.config.steps_offset if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to @@ -310,6 +308,7 @@ def step_plms( prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps if self.counter != 1: + self.ets = self.ets[-3:] self.ets.append(model_output) else: prev_timestep = timestep diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 2e587d5e2121..298e62de20d1 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -23,7 +23,12 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: @@ -87,8 +92,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2202.09778 @@ -114,6 +119,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): stable diffusion. """ + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + @property def has_state(self): return True @@ -168,6 +175,8 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, sha the `FlaxPNDMScheduler` state data class instance. num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. + shape (`Tuple`): + the shape of the samples to be generated. """ offset = self.config.steps_offset @@ -509,13 +518,11 @@ def add_noise( ) -> jnp.ndarray: sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod[..., None] + sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None] + sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py new file mode 100644 index 000000000000..55625c1bfa92 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -0,0 +1,322 @@ +# Copyright 2022 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class RePaintSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from + the current timestep. `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: torch.FloatTensor + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class RePaintScheduler(SchedulerMixin, ConfigMixin): + """ + RePaint is a schedule for DDPM inpainting inside a given mask. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + eta (`float`): + The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 -0.0 is DDIM and + 1.0 is DDPM scheduler respectively. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample between -1 and 1 for numerical stability. + + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + eta: float = 0.0, + trained_betas: Optional[np.ndarray] = None, + clip_sample: bool = True, + ): + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + elif beta_schedule == "sigmoid": + # GeoDiff sigmoid schedule + betas = torch.linspace(-6, 6, num_train_timesteps) + self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) + + self.final_alpha_cumprod = torch.tensor(1.0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + self.eta = eta + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def set_timesteps( + self, + num_inference_steps: int, + jump_length: int = 10, + jump_n_sample: int = 10, + device: Union[str, torch.device] = None, + ): + num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) + self.num_inference_steps = num_inference_steps + + timesteps = [] + + jumps = {} + for j in range(0, num_inference_steps - jump_length, jump_length): + jumps[j] = jump_n_sample - 1 + + t = num_inference_steps + while t >= 1: + t = t - 1 + timesteps.append(t) + + if jumps.get(t, 0) > 0: + jumps[t] = jumps[t] - 1 + for _ in range(jump_length): + t = t + 1 + timesteps.append(t) + + timesteps = np.array(timesteps) * (self.config.num_train_timesteps // self.num_inference_steps) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_variance(self, t): + prev_timestep = t - self.config.num_train_timesteps // self.num_inference_steps + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from + # https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get + # previous sample x_{t-1} ~ N(pred_prev_sample, variance) == add + # variance to pred_sample + # Is equivalent to formula (16) in https://arxiv.org/pdf/2010.02502.pdf + # without eta. + # variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + original_image: torch.FloatTensor, + mask: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[RePaintSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned + diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + original_image (`torch.FloatTensor`): + the original image to inpaint on. + mask (`torch.FloatTensor`): + the mask where 0.0 values define which part of the original image to inpaint (change). + generator (`torch.Generator`, *optional*): random number generator. + return_dict (`bool`): option for returning tuple rather than + DDPMSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.RePaintSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.RePaintSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + t = timestep + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + + # 3. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # We choose to follow RePaint Algorithm 1 to get x_{t-1}, however we + # substitute formula (7) in the algorithm coming from DDPM paper + # (formula (4) Algorithm 2 - Sampling) with formula (12) from DDIM paper. + # DDIM schedule gives the same results as DDPM with eta = 1.0 + # Noise is being reused in 7. and 8., but no impact on quality has + # been observed. + + # 5. Add noise + noise = torch.randn( + model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device + ) + std_dev_t = self.eta * self._get_variance(timestep) ** 0.5 + + variance = 0 + if t > 0 and self.eta > 0: + variance = std_dev_t * noise + + # 6. compute "direction pointing to x_t" of formula (12) + # from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * model_output + + # 7. compute x_{t-1} of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance + + # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf + prev_known_part = (alpha_prod_t**0.5) * original_image + ((1 - alpha_prod_t) ** 0.5) * noise + + # 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf + pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part + + if not return_dict: + return ( + pred_prev_sample, + pred_original_sample, + ) + + return RePaintSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + + def undo_step(self, sample, timestep, generator=None): + n = self.config.num_train_timesteps // self.num_inference_steps + + for i in range(n): + beta = self.betas[timestep + i] + noise = torch.randn(sample.shape, generator=generator, device=sample.device) + + # 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf + sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise + + return sample + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.") + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index d31adbc3c673..1d436ab0cbbf 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -50,8 +50,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index c4d802f83f94..d1f762bc90c4 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -22,7 +22,7 @@ from jax import random from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left @flax.struct.dataclass @@ -64,8 +64,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. @@ -80,6 +80,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): correct_steps (`int`): number of correction steps performed on a produced sample. """ + @property + def has_state(self): + return True + @register_to_config def __init__( self, @@ -90,12 +94,20 @@ def __init__( sampling_eps: float = 1e-5, correct_steps: int = 1, ): - state = ScoreSdeVeSchedulerState.create() + pass - self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps) + def create_state(self): + state = ScoreSdeVeSchedulerState.create() + return self.set_sigmas( + state, + self.config.num_train_timesteps, + self.config.sigma_min, + self.config.sigma_max, + self.config.sampling_eps, + ) def set_timesteps( - self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None + self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None ) -> ScoreSdeVeSchedulerState: """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -193,8 +205,7 @@ def step_pred( # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) # also equation 47 shows the analog from SDE models to ancestral sampling methods diffusion = diffusion.flatten() - while len(diffusion.shape) < len(sample.shape): - diffusion = diffusion[:, None] + diffusion = broadcast_to_shape_from_left(diffusion, sample.shape) drift = drift - diffusion**2 * model_output # equation 6: sample noise for the diffusion term of @@ -252,8 +263,7 @@ def step_correct( # compute corrected sample: model_output term and noise term step_size = step_size.flatten() - while len(step_size.shape) < len(sample.shape): - step_size = step_size[:, None] + step_size = broadcast_to_shape_from_left(step_size, sample.shape) prev_sample_mean = sample + step_size * model_output prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index a37a159a87ac..537d6f7e2a29 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -29,8 +29,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more information, see the original paper: https://arxiv.org/abs/2011.13456 diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 29bf982f6adf..90ab674e38a4 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib +import os from dataclasses import dataclass +from typing import Any, Dict, Optional, Union import torch @@ -38,6 +41,114 @@ class SchedulerOutput(BaseOutput): class SchedulerMixin: """ Mixin containing common functions for the schedulers. + + Class attributes: + - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that + `from_config` can be used from a class different than the one used to save the config (should be overridden + by parent class). """ config_name = SCHEDULER_CONFIG_NAME + _compatibles = [] + has_compatibles = True + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Dict[str, Any] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + r""" + Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing the schedluer configurations saved using + [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`. + subfolder (`str`, *optional*): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + """ + config, kwargs = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + **kwargs, + ) + return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~SchedulerMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py index 63de51f146f5..5dc28c25d9d6 100644 --- a/src/diffusers/schedulers/scheduling_utils_flax.py +++ b/src/diffusers/schedulers/scheduling_utils_flax.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib +import os from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union import jax.numpy as jnp -from ..utils import BaseOutput +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput SCHEDULER_CONFIG_NAME = "scheduler_config.json" +_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS] @dataclass @@ -38,6 +42,128 @@ class FlaxSchedulerOutput(BaseOutput): class FlaxSchedulerMixin: """ Mixin containing common functions for the schedulers. + + Class attributes: + - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that + `from_config` can be used from a class different than the one used to save the config (should be overridden + by parent class). """ config_name = SCHEDULER_CONFIG_NAME + _compatibles = [] + has_compatibles = True + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Dict[str, Any] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + r""" + Instantiate a Scheduler class from a pre-defined JSON-file. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`], + e.g., `./my_model_directory/`. + subfolder (`str`, *optional*): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + """ + config, kwargs = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + **kwargs, + ) + scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs) + + if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False): + state = scheduler.create_state() + + if return_unused_kwargs: + return scheduler, state, unused_kwargs + + return scheduler, state + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~FlaxSchedulerMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes + + +def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray: + assert len(shape) >= x.ndim + return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape) diff --git a/src/diffusers/schedulers/scheduling_vq_diffusion.py b/src/diffusers/schedulers/scheduling_vq_diffusion.py new file mode 100644 index 000000000000..91c46e655451 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_vq_diffusion.py @@ -0,0 +1,494 @@ +# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class VQDiffusionSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + Computed sample x_{t-1} of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.LongTensor + + +def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor: + """ + Convert batch of vector of class indices into batch of log onehot vectors + + Args: + x (`torch.LongTensor` of shape `(batch size, vector length)`): + Batch of class indices + + num_classes (`int`): + number of classes to be used for the onehot vectors + + Returns: + `torch.FloatTensor` of shape `(batch size, num classes, vector length)`: + Log onehot vectors + """ + x_onehot = F.one_hot(x, num_classes) + x_onehot = x_onehot.permute(0, 2, 1) + log_x = torch.log(x_onehot.float().clamp(min=1e-30)) + return log_x + + +def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor: + """ + Apply gumbel noise to `logits` + """ + uniform = torch.rand(logits.shape, device=logits.device, generator=generator) + gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30) + noised = gumbel_noise + logits + return noised + + +def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009): + """ + Cumulative and non-cumulative alpha schedules. + + See section 4.1. + """ + att = ( + np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (alpha_cum_end - alpha_cum_start) + + alpha_cum_start + ) + att = np.concatenate(([1], att)) + at = att[1:] / att[:-1] + att = np.concatenate((att[1:], [1])) + return at, att + + +def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamma_cum_end=0.99999): + """ + Cumulative and non-cumulative gamma schedules. + + See section 4.1. + """ + ctt = ( + np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (gamma_cum_end - gamma_cum_start) + + gamma_cum_start + ) + ctt = np.concatenate(([0], ctt)) + one_minus_ctt = 1 - ctt + one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1] + ct = 1 - one_minus_ct + ctt = np.concatenate((ctt[1:], [0])) + return ct, ctt + + +class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + The VQ-diffusion transformer outputs predicted probabilities of the initial unnoised image. + + The VQ-diffusion scheduler converts the transformer's output into a sample for the unnoised image at the previous + diffusion timestep. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2111.14822 + + Args: + num_vec_classes (`int`): + The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked + latent pixel. + + num_train_timesteps (`int`): + Number of diffusion steps used to train the model. + + alpha_cum_start (`float`): + The starting cumulative alpha value. + + alpha_cum_end (`float`): + The ending cumulative alpha value. + + gamma_cum_start (`float`): + The starting cumulative gamma value. + + gamma_cum_end (`float`): + The ending cumulative gamma value. + """ + + @register_to_config + def __init__( + self, + num_vec_classes: int, + num_train_timesteps: int = 100, + alpha_cum_start: float = 0.99999, + alpha_cum_end: float = 0.000009, + gamma_cum_start: float = 0.000009, + gamma_cum_end: float = 0.99999, + ): + self.num_embed = num_vec_classes + + # By convention, the index for the mask class is the last class index + self.mask_class = self.num_embed - 1 + + at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end) + ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end) + + num_non_mask_classes = self.num_embed - 1 + bt = (1 - at - ct) / num_non_mask_classes + btt = (1 - att - ctt) / num_non_mask_classes + + at = torch.tensor(at.astype("float64")) + bt = torch.tensor(bt.astype("float64")) + ct = torch.tensor(ct.astype("float64")) + log_at = torch.log(at) + log_bt = torch.log(bt) + log_ct = torch.log(ct) + + att = torch.tensor(att.astype("float64")) + btt = torch.tensor(btt.astype("float64")) + ctt = torch.tensor(ctt.astype("float64")) + log_cumprod_at = torch.log(att) + log_cumprod_bt = torch.log(btt) + log_cumprod_ct = torch.log(ctt) + + self.log_at = log_at.float() + self.log_bt = log_bt.float() + self.log_ct = log_ct.float() + self.log_cumprod_at = log_cumprod_at.float() + self.log_cumprod_bt = log_cumprod_bt.float() + self.log_cumprod_ct = log_cumprod_ct.float() + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + + device (`str` or `torch.device`): + device to place the timesteps and the diffusion process parameters (alpha, beta, gamma) on. + """ + self.num_inference_steps = num_inference_steps + timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.log_at = self.log_at.to(device) + self.log_bt = self.log_bt.to(device) + self.log_ct = self.log_ct.to(device) + self.log_cumprod_at = self.log_cumprod_at.to(device) + self.log_cumprod_bt = self.log_cumprod_bt.to(device) + self.log_cumprod_ct = self.log_cumprod_ct.to(device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: torch.long, + sample: torch.LongTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[VQDiffusionSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep via the reverse transition distribution i.e. Equation (11). See the + docstring for `self.q_posterior` for more in depth docs on how Equation (11) is computed. + + Args: + log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`): + The log probabilities for the predicted classes of the initial latent pixels. Does not include a + prediction for the masked class as the initial unnoised image cannot be masked. + + t (`torch.long`): + The timestep that determines which transition matrices are used. + + x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t` + + generator: (`torch.Generator` or None): + RNG for the noise applied to p(x_{t-1} | x_t) before it is sampled from. + + return_dict (`bool`): + option for returning tuple rather than VQDiffusionSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + if timestep == 0: + log_p_x_t_min_1 = model_output + else: + log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep) + + log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator) + + x_t_min_1 = log_p_x_t_min_1.argmax(dim=1) + + if not return_dict: + return (x_t_min_1,) + + return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1) + + def q_posterior(self, log_p_x_0, x_t, t): + """ + Calculates the log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11). + + Instead of directly computing equation (11), we use Equation (5) to restate Equation (11) in terms of only + forward probabilities. + + Equation (11) stated in terms of forward probabilities via Equation (5): + + Where: + - the sum is over x_0 = {C_0 ... C_{k-1}} (classes for x_0) + + p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) ) + + Args: + log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`): + The log probabilities for the predicted classes of the initial latent pixels. Does not include a + prediction for the masked class as the initial unnoised image cannot be masked. + + x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t` + + t (torch.Long): + The timestep that determines which transition matrix is used. + + Returns: + `torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`: + The log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11). + """ + log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed) + + log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class( + t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True + ) + + log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class( + t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False + ) + + # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + # . . . + # . . . + # . . . + # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) + q = log_p_x_0 - log_q_x_t_given_x_0 + + # sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... , + # sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) + q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True) + + # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0 ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n + # . . . + # . . . + # . . . + # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0 ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n + q = q - q_log_sum_exp + + # (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} + # . . . + # . . . + # . . . + # (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} + # c_cumulative_{t-1} ... c_cumulative_{t-1} + q = self.apply_cumulative_transitions(q, t - 1) + + # ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0 ... ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n + # . . . + # . . . + # . . . + # ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0 ... ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n + # c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 ... c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 + log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp + + # For each column, there are two possible cases. + # + # Where: + # - sum(p_n(x_0))) is summing over all classes for x_0 + # - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's) + # - C_j is the class transitioning to + # + # 1. x_t is masked i.e. x_t = c_k + # + # Simplifying the expression, the column vector is: + # . + # . + # . + # (c_t / c_cumulative_t) * (a_cumulative_{t-1} * p_n(x_0 = C_i | x_t) + b_cumulative_{t-1} * sum(p_n(x_0))) + # . + # . + # . + # (c_cumulative_{t-1} / c_cumulative_t) * sum(p_n(x_0)) + # + # From equation (11) stated in terms of forward probabilities, the last row is trivially verified. + # + # For the other rows, we can state the equation as ... + # + # (c_t / c_cumulative_t) * [b_cumulative_{t-1} * p(x_0=c_0) + ... + (a_cumulative_{t-1} + b_cumulative_{t-1}) * p(x_0=C_i) + ... + b_cumulative_{k-1} * p(x_0=c_{k-1})] + # + # This verifies the other rows. + # + # 2. x_t is not masked + # + # Simplifying the expression, there are two cases for the rows of the column vector, where C_j = C_i and where C_j != C_i: + # . + # . + # . + # C_j != C_i: b_t * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / b_cumulative_t) * p_n(x_0 = C_i) + ... + (b_cumulative_{t-1} / (a_cumulative_t + b_cumulative_t)) * p_n(c_0=C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1})) + # . + # . + # . + # C_j = C_i: (a_t + b_t) * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_i = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1})) + # . + # . + # . + # 0 + # + # The last row is trivially verified. The other rows can be verified by directly expanding equation (11) stated in terms of forward probabilities. + return log_p_x_t_min_1 + + def log_Q_t_transitioning_to_known_class( + self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool + ): + """ + Returns the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each + latent pixel in `x_t`. + + See equation (7) for the complete non-cumulative transition matrix. The complete cumulative transition matrix + is the same structure except the parameters (alpha, beta, gamma) are the cumulative analogs. + + Args: + t (torch.Long): + The timestep that determines which transition matrix is used. + + x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t`. + + log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`): + The log one-hot vectors of `x_t` + + cumulative (`bool`): + If cumulative is `False`, we use the single step transition matrix `t-1`->`t`. If cumulative is `True`, + we use the cumulative transition matrix `0`->`t`. + + Returns: + `torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`: + Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability + transition matrix. + + When non cumulative, returns `self.num_classes - 1` rows because the initial latent pixel cannot be + masked. + + Where: + - `q_n` is the probability distribution for the forward process of the `n`th latent pixel. + - C_0 is a class of a latent pixel embedding + - C_k is the class of the masked latent pixel + + non-cumulative result (omitting logarithms): + ``` + q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0) + . . . + . . . + . . . + q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k) + ``` + + cumulative result (omitting logarithms): + ``` + q_0_cumulative(x_t | x_0 = C_0) ... q_n_cumulative(x_t | x_0 = C_0) + . . . + . . . + . . . + q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1}) + ``` + """ + if cumulative: + a = self.log_cumprod_at[t] + b = self.log_cumprod_bt[t] + c = self.log_cumprod_ct[t] + else: + a = self.log_at[t] + b = self.log_bt[t] + c = self.log_ct[t] + + if not cumulative: + # The values in the onehot vector can also be used as the logprobs for transitioning + # from masked latent pixels. If we are not calculating the cumulative transitions, + # we need to save these vectors to be re-appended to the final matrix so the values + # aren't overwritten. + # + # `P(x_t!=mask|x_{t-1=mask}) = 0` and 0 will be the value of the last row of the onehot vector + # if x_t is not masked + # + # `P(x_t=mask|x_{t-1=mask}) = 1` and 1 will be the value of the last row of the onehot vector + # if x_t is masked + log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1) + + # `index_to_log_onehot` will add onehot vectors for masked pixels, + # so the default one hot matrix has one too many rows. See the doc string + # for an explanation of the dimensionality of the returned matrix. + log_onehot_x_t = log_onehot_x_t[:, :-1, :] + + # this is a cheeky trick to produce the transition probabilities using log one-hot vectors. + # + # Don't worry about what values this sets in the columns that mark transitions + # to masked latent pixels. They are overwrote later with the `mask_class_mask`. + # + # Looking at the below logspace formula in non-logspace, each value will evaluate to either + # `1 * a + b = a + b` where `log_Q_t` has the one hot value in the column + # or + # `0 * a + b = b` where `log_Q_t` has the 0 values in the column. + # + # See equation 7 for more details. + log_Q_t = (log_onehot_x_t + a).logaddexp(b) + + # The whole column of each masked pixel is `c` + mask_class_mask = x_t == self.mask_class + mask_class_mask = mask_class_mask.unsqueeze(1).expand(-1, self.num_embed - 1, -1) + log_Q_t[mask_class_mask] = c + + if not cumulative: + log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1) + + return log_Q_t + + def apply_cumulative_transitions(self, q, t): + bsz = q.shape[0] + a = self.log_cumprod_at[t] + b = self.log_cumprod_bt[t] + c = self.log_cumprod_ct[t] + + num_latent_pixels = q.shape[2] + c = c.expand(bsz, 1, num_latent_pixels) + + q = (q + a).logaddexp(b) + q = torch.cat((q, c), dim=1) + + return q diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 51798e2ad765..e86f3b801ae4 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -31,16 +31,29 @@ is_scipy_available, is_tf_available, is_torch_available, + is_torch_version, is_transformers_available, + is_transformers_version, is_unidecode_available, requires_backends, ) from .logging import get_logger from .outputs import BaseOutput +from .pil_utils import PIL_INTERPOLATION if is_torch_available(): - from .testing_utils import floats_tensor, load_image, parse_flag_from_env, slow, torch_device + from .testing_utils import ( + floats_tensor, + load_hf_numpy, + load_image, + load_numpy, + parse_flag_from_env, + require_torch_gpu, + slow, + torch_all_close, + torch_device, + ) logger = get_logger(__name__) @@ -56,7 +69,18 @@ WEIGHTS_NAME = "diffusion_pytorch_model.bin" FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" ONNX_WEIGHTS_NAME = "model.onnx" +ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" DIFFUSERS_CACHE = default_cache_path DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) + +_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [ + "DDIMScheduler", + "DDPMScheduler", + "PNDMScheduler", + "LMSDiscreteScheduler", + "EulerDiscreteScheduler", + "EulerAncestralDiscreteScheduler", + "DPMSolverMultistepScheduler", +] diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py index eac43031574f..7c8bfc901b31 100644 --- a/src/diffusers/utils/deprecation_utils.py +++ b/src/diffusers/utils/deprecation_utils.py @@ -32,7 +32,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn if warning is not None: warning = warning + " " if standard_warn else "" - warnings.warn(warning + message, DeprecationWarning) + warnings.warn(warning + message, FutureWarning) if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: call_frame = inspect.getouterframes(inspect.currentframe())[1] diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index 708022d85b31..8e308bb41bea 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -94,6 +94,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) +class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxKarrasVeScheduler(metaclass=DummyObject): _backends = ["flax"] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index ee748f5b1d6a..af2e0c7c61d6 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -34,6 +34,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class Transformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UNet1DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class UNet2DConditionModel(metaclass=DummyObject): _backends = ["torch"] @@ -122,6 +152,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DanceDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DDIMPipeline(metaclass=DummyObject): _backends = ["torch"] @@ -182,6 +227,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LDMSuperResolutionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PNDMPipeline(metaclass=DummyObject): _backends = ["torch"] @@ -197,6 +257,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class RePaintPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ScoreSdeVePipeline(metaclass=DummyObject): _backends = ["torch"] @@ -242,6 +317,66 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DPMSolverMultistepScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class EulerAncestralDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class EulerDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class IPNDMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class KarrasVeScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -272,6 +407,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class RePaintScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class SchedulerMixin(metaclass=DummyObject): _backends = ["torch"] @@ -302,6 +452,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class VQDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class EMAModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py index d099b837296a..ae9412a95682 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py @@ -4,6 +4,66 @@ from ..utils import DummyObject, requires_backends +class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + +class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + +class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + +class OnnxStableDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + class StableDiffusionOnnxPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "onnx"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ce9a02bca150..2d932d240508 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -4,6 +4,51 @@ from ..utils import DummyObject, requires_backends +class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AltDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class CycleDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LDMTextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -19,6 +64,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionImageVariationPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -49,6 +109,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -62,3 +137,108 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionPipelineSafe(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionUpscalePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionImageVariationPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionTextToImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VQDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index b2aabee70c92..c0294b4a3d23 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -15,11 +15,14 @@ Import utilities: Utilities related to imports and our lazy inits. """ import importlib.util +import operator as op import os import sys from collections import OrderedDict +from typing import Union from packaging import version +from packaging.version import Version, parse from . import logging @@ -40,6 +43,8 @@ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() +STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} + _torch_version = "N/A" if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: _torch_available = importlib.util.find_spec("torch") is not None @@ -90,7 +95,8 @@ logger.info("Disabling Tensorflow because USE_TORCH is set") _tf_available = False - +_jax_version = "N/A" +_flax_version = "N/A" if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None if _flax_available: @@ -136,6 +142,7 @@ _modelcards_available = False +_onnxruntime_version = "N/A" _onnx_available = importlib.util.find_spec("onnxruntime") is not None if _onnx_available: candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino") @@ -166,6 +173,18 @@ except importlib_metadata.PackageNotFoundError: _accelerate_available = False +_xformers_available = importlib.util.find_spec("xformers") is not None +try: + _xformers_version = importlib_metadata.version("xformers") + if _torch_available: + import torch + + if torch.__version__ < version.Version("1.12"): + raise ValueError("PyTorch should be >= 1.12") + logger.debug(f"Successfully imported xformers version {_xformers_version}") +except importlib_metadata.PackageNotFoundError: + _xformers_available = False + def is_torch_available(): return _torch_available @@ -203,6 +222,10 @@ def is_scipy_available(): return _scipy_available +def is_xformers_available(): + return _xformers_available + + def is_accelerate_available(): return _accelerate_available @@ -280,6 +303,17 @@ def requires_backends(obj, backends): if failed: raise ImportError("".join(failed)) + if name in [ + "VersatileDiffusionTextToImagePipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionDualGuidedPipeline", + "StableDiffusionImageVariationPipeline", + ] and is_transformers_version("<", "4.25.0.dev0"): + raise ImportError( + f"You need to install `transformers` from 'main' in order to use {name}: \n```\n pip install" + " git+https://github.com/huggingface/transformers \n```" + ) + class DummyObject(type): """ @@ -291,3 +325,50 @@ def __getattr__(cls, key): if key.startswith("_"): return super().__getattr__(cls, key) requires_backends(cls, cls._backends) + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 +def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str): + """ + Args: + Compares a library version to some requirement using a given operation. + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. + requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}") + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = parse(importlib_metadata.version(library_or_version)) + return operation(library_or_version, parse(requirement_version)) + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338 +def is_torch_version(operation: str, version: str): + """ + Args: + Compares the current PyTorch version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + return compare_versions(parse(_torch_version), operation, version) + + +def is_transformers_version(operation: str, version: str): + """ + Args: + Compares the current Transformers version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + if not _transformers_available: + return False + return compare_versions(parse(_transformers_version), operation, version) diff --git a/src/diffusers/utils/pil_utils.py b/src/diffusers/utils/pil_utils.py new file mode 100644 index 000000000000..39d0a15a4e2f --- /dev/null +++ b/src/diffusers/utils/pil_utils.py @@ -0,0 +1,21 @@ +import PIL.Image +import PIL.ImageOps +from packaging import version + + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 682a7471f6af..bf398e5b6fe5 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -1,18 +1,23 @@ import inspect +import logging import os import random import re import unittest +import urllib.parse from distutils.util import strtobool +from io import BytesIO, StringIO from pathlib import Path from typing import Union +import numpy as np + import PIL.Image import PIL.ImageOps import requests from packaging import version -from .import_utils import is_flax_available, is_torch_available +from .import_utils import is_flax_available, is_onnx_available, is_torch_available global_rng = random.Random() @@ -27,7 +32,17 @@ ) if is_torch_higher_equal_than_1_12: - torch_device = "mps" if torch.backends.mps.is_available() else torch_device + # Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details + mps_backend_registered = hasattr(torch.backends, "mps") + torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device + + +def torch_all_close(a, b, *args, **kwargs): + if not is_torch_available(): + raise ValueError("PyTorch needs to be installed to use this function.") + if not torch.allclose(a, b, *args, **kwargs): + assert False, f"Max diff is absolute {(a - b).abs().max()}. Diff tensor is {(a - b).abs()}." + return True def get_tests_dir(append_path=None): @@ -96,6 +111,20 @@ def slow(test_case): return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) +def require_torch(test_case): + """ + Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed. + """ + return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) + + +def require_torch_gpu(test_case): + """Decorator marking a test that requires CUDA and PyTorch.""" + return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")( + test_case + ) + + def require_flax(test_case): """ Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed @@ -103,6 +132,36 @@ def require_flax(test_case): return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case) +def require_onnxruntime(test_case): + """ + Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed. + """ + return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case) + + +def load_numpy(arry: Union[str, np.ndarray]) -> np.ndarray: + if isinstance(arry, str): + if arry.startswith("http://") or arry.startswith("https://"): + response = requests.get(arry) + response.raise_for_status() + arry = np.load(BytesIO(response.content)) + elif os.path.isfile(arry): + arry = np.load(arry) + else: + raise ValueError( + f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path" + ) + elif isinstance(arry, np.ndarray): + pass + else: + raise ValueError( + "Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a" + " ndarray." + ) + + return arry + + def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: """ Args: @@ -132,6 +191,15 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: return image +def load_hf_numpy(path) -> np.ndarray: + if not path.startswith("http://") or path.startswith("https://"): + path = os.path.join( + "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path) + ) + + return load_numpy(path) + + # --- pytest conf functions --- # # to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once @@ -284,3 +352,42 @@ def summary_failures_short(tr): tr._tw = orig_writer tr.reportchars = orig_reportchars config.option.tbstyle = orig_tbstyle + + +class CaptureLogger: + """ + Args: + Context manager to capture `logging` streams + logger: 'logging` logger object + Returns: + The captured output is available via `self.out` + Example: + ```python + >>> from diffusers import logging + >>> from diffusers.testing_utils import CaptureLogger + + >>> msg = "Testing 1, 2, 3" + >>> logging.set_verbosity_info() + >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py") + >>> with CaptureLogger(logger) as cl: + ... logger.info(msg) + >>> assert cl.out, msg + "\n" + ``` + """ + + def __init__(self, logger): + self.logger = logger + self.io = StringIO() + self.sh = logging.StreamHandler(self.io) + self.out = "" + + def __enter__(self): + self.logger.addHandler(self.sh) + return self + + def __exit__(self, *exc): + self.logger.removeHandler(self.sh) + self.out = self.io.getvalue() + + def __repr__(self): + return f"captured: {self.out}\n" diff --git a/tests/fixtures/custom_pipeline/pipeline.py b/tests/fixtures/custom_pipeline/pipeline.py index 10a22edaa490..e7429d0a1945 100644 --- a/tests/fixtures/custom_pipeline/pipeline.py +++ b/tests/fixtures/custom_pipeline/pipeline.py @@ -42,7 +42,6 @@ def __call__( self, batch_size: int = 1, generator: Optional[torch.Generator] = None, - eta: float = 0.0, num_inference_steps: int = 50, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -89,7 +88,7 @@ def __call__( # 2. predict previous mean of image x_t-1 and add variance depending on eta # eta corresponds to η in paper and should be between [0, 1] # do x_t -> x_t-1 - image = self.scheduler.step(model_output, t, image, eta).prev_sample + image = self.scheduler.step(model_output, t, image).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/tests/fixtures/custom_pipeline/what_ever.py b/tests/fixtures/custom_pipeline/what_ever.py new file mode 100644 index 000000000000..e7429d0a1945 --- /dev/null +++ b/tests/fixtures/custom_pipeline/what_ever.py @@ -0,0 +1,101 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + + +from typing import Optional, Tuple, Union + +import torch + +from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class CustomLocalPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + def __init__(self, unet, scheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[torch.Generator] = None, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + eta (`float`, *optional*, defaults to 0.0): + The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + + # Sample gaussian noise to begin loop + image = torch.randn( + (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + generator=generator, + ) + image = image.to(self.device) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t).sample + + # 2. predict previous mean of image x_t-1 and add variance depending on eta + # eta corresponds to η in paper and should be between [0, 1] + # do x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,), "This is a local test" + + return ImagePipelineOutput(images=image), "This is a local test" diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py new file mode 100644 index 000000000000..089d935651a5 --- /dev/null +++ b/tests/models/test_models_unet_1d.py @@ -0,0 +1,276 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import UNet1DModel +from diffusers.utils import floats_tensor, slow, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class UNet1DModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet1DModel + + @property + def dummy_input(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 14, 16) + + @property + def output_shape(self): + return (4, 14, 16) + + def test_ema_training(self): + pass + + def test_training(self): + pass + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_determinism(self): + super().test_determinism() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_outputs_equivalence(self): + super().test_outputs_equivalence() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_pretrained_save_pretrained(self): + super().test_from_pretrained_save_pretrained() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_model_from_pretrained(self): + super().test_model_from_pretrained() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_output(self): + super().test_output() + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64, 128, 256), + "in_channels": 14, + "out_channels": 14, + "time_embedding_type": "positional", + "use_timestep_embedding": True, + "flip_sin_to_cos": False, + "freq_shift": 1.0, + "out_block_type": "OutConv1DBlock", + "mid_block_type": "MidResTemporalBlock1D", + "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"), + "act_fn": "mish", + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_pretrained_hub(self): + model, loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet" + ) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_output_pretrained(self): + model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet") + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = model.in_channels + seq_len = 16 + noise = torch.randn((1, seq_len, num_features)).permute( + 0, 2, 1 + ) # match original, we can update values and remove + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = model(noise, time_step).sample.permute(0, 2, 1) + + output_slice = output[0, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) + + def test_forward_with_norm_groups(self): + # Not implemented yet for this UNet + pass + + @slow + def test_unet_1d_maestro(self): + model_id = "harmonai/maestro-150k" + model = UNet1DModel.from_pretrained(model_id, subfolder="unet") + model.to(torch_device) + + sample_size = 65536 + noise = torch.sin(torch.arange(sample_size)[None, None, :].repeat(1, 2, 1)).to(torch_device) + timestep = torch.tensor([1]).to(torch_device) + + with torch.no_grad(): + output = model(noise, timestep).sample + + output_sum = output.abs().sum() + output_max = output.abs().max() + + assert (output_sum - 224.0896).abs() < 4e-2 + assert (output_max - 0.0607).abs() < 4e-4 + + +class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet1DModel + + @property + def dummy_input(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 14, 16) + + @property + def output_shape(self): + return (4, 14, 1) + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_determinism(self): + super().test_determinism() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_outputs_equivalence(self): + super().test_outputs_equivalence() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_pretrained_save_pretrained(self): + super().test_from_pretrained_save_pretrained() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_model_from_pretrained(self): + super().test_model_from_pretrained() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_output(self): + # UNetRL is a value-function is different output shape + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1)) + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_ema_training(self): + pass + + def test_training(self): + pass + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 14, + "out_channels": 14, + "down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"], + "up_block_types": [], + "out_block_type": "ValueFunction", + "mid_block_type": "ValueFunctionMidBlock1D", + "block_out_channels": [32, 64, 128, 256], + "layers_per_block": 1, + "downsample_each_block": True, + "use_timestep_embedding": True, + "freq_shift": 1.0, + "flip_sin_to_cos": False, + "time_embedding_type": "positional", + "act_fn": "mish", + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_pretrained_hub(self): + value_function, vf_loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" + ) + self.assertIsNotNone(value_function) + self.assertEqual(len(vf_loading_info["missing_keys"]), 0) + + value_function.to(torch_device) + image = value_function(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_output_pretrained(self): + value_function, vf_loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" + ) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = value_function.in_channels + seq_len = 14 + noise = torch.randn((1, seq_len, num_features)).permute( + 0, 2, 1 + ) # match original, we can update values and remove + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = value_function(noise, time_step).sample + + # fmt: off + expected_output_slice = torch.tensor([165.25] * seq_len) + # fmt: on + self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) + + def test_forward_with_norm_groups(self): + # Not implemented yet for this UNet + pass diff --git a/tests/test_models_unet.py b/tests/models/test_models_unet_2d.py similarity index 53% rename from tests/test_models_unet.py rename to tests/models/test_models_unet_2d.py index 7a16a8592644..02c6d314bfff 100644 --- a/tests/test_models_unet.py +++ b/tests/models/test_models_unet_2d.py @@ -21,15 +21,25 @@ import torch from diffusers import UNet2DConditionModel, UNet2DModel -from diffusers.utils import floats_tensor, slow, torch_device +from diffusers.utils import ( + floats_tensor, + load_hf_numpy, + logging, + require_torch_gpu, + slow, + torch_all_close, + torch_device, +) +from parameterized import parameterized -from .test_modeling_common import ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin +logger = logging.get_logger(__name__) torch.backends.cuda.matmul.allow_tf32 = False -class UnetModelTests(ModelTesterMixin, unittest.TestCase): +class Unet2DModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet2DModel @property @@ -66,28 +76,6 @@ def prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict -# TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints -# def test_output_pretrained(self): -# model = UNet2DModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet") -# model.eval() -# -# torch.manual_seed(0) -# if torch.cuda.is_available(): -# torch.cuda.manual_seed_all(0) -# -# noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) -# time_step = torch.tensor([10]) -# -# with torch.no_grad(): -# output = model(noise, time_step).sample -# -# output_slice = output[0, -1, -3:, -3:].flatten() -# fmt: off -# expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053]) -# fmt: on -# self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) - - class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet2DModel @@ -137,9 +125,7 @@ def test_from_pretrained_hub(self): @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU") def test_from_pretrained_accelerate(self): - model, _ = UNet2DModel.from_pretrained( - "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto" - ) + model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) model.to(torch_device) image = model(**self.dummy_input).sample @@ -147,9 +133,8 @@ def test_from_pretrained_accelerate(self): @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU") def test_from_pretrained_accelerate_wont_change_results(self): - model_accelerate, _ = UNet2DModel.from_pretrained( - "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto" - ) + # by defautl model loading will use accelerate as `low_cpu_mem_usage=True` + model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) model_accelerate.to(torch_device) model_accelerate.eval() @@ -170,12 +155,14 @@ def test_from_pretrained_accelerate_wont_change_results(self): torch.cuda.empty_cache() gc.collect() - model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) + model_normal_load, _ = UNet2DModel.from_pretrained( + "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False + ) model_normal_load.to(torch_device) model_normal_load.eval() arr_normal_load = model_normal_load(noise, time_step)["sample"] - assert torch.allclose(arr_accelerate, arr_normal_load, rtol=1e-3) + assert torch_all_close(arr_accelerate, arr_normal_load, rtol=1e-3) @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU") def test_memory_footprint_gets_reduced(self): @@ -183,9 +170,8 @@ def test_memory_footprint_gets_reduced(self): gc.collect() tracemalloc.start() - model_accelerate, _ = UNet2DModel.from_pretrained( - "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto" - ) + # by defautl model loading will use accelerate as `low_cpu_mem_usage=True` + model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) model_accelerate.to(torch_device) model_accelerate.eval() _, peak_accelerate = tracemalloc.get_traced_memory() @@ -194,7 +180,9 @@ def test_memory_footprint_gets_reduced(self): torch.cuda.empty_cache() gc.collect() - model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) + model_normal_load, _ = UNet2DModel.from_pretrained( + "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False + ) model_normal_load.to(torch_device) model_normal_load.eval() _, peak_normal = tracemalloc.get_traced_memory() @@ -226,7 +214,7 @@ def test_output_pretrained(self): expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800]) # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) + self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3)) class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): @@ -306,32 +294,45 @@ def test_gradient_checkpointing(self): named_params = dict(model.named_parameters()) named_params_2 = dict(model_2.named_parameters()) for name, param in named_params.items(): - self.assertTrue(torch.allclose(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) + self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) + def test_model_with_attention_head_dim_tuple(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() -# TODO(Patrick) - Re-add this test after having cleaned up LDM -# def test_output_pretrained_spatial_transformer(self): -# model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial") -# model.eval() -# -# torch.manual_seed(0) -# if torch.cuda.is_available(): -# torch.cuda.manual_seed_all(0) -# -# noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) -# context = torch.ones((1, 16, 64), dtype=torch.float32) -# time_step = torch.tensor([10] * noise.shape[0]) -# -# with torch.no_grad(): -# output = model(noise, time_step, context=context) -# -# output_slice = output[0, -1, -3:, -3:].flatten() -# fmt: off -# expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890]) -# fmt: on -# -# self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) -# + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_model_with_use_linear_projection(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["use_linear_projection"] = True + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): @@ -419,7 +420,7 @@ def test_output_pretrained_ve_mid(self): expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114]) # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) def test_output_pretrained_ve_large(self): model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update") @@ -444,8 +445,197 @@ def test_output_pretrained_ve_large(self): expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256]) # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) def test_forward_with_norm_groups(self): # not required for this model pass + + +@slow +class UNet2DConditionModelIntegrationTests(unittest.TestCase): + def get_file_format(self, seed, shape): + return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" + + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): + dtype = torch.float16 if fp16 else torch.float32 + image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype) + return image + + def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): + revision = "fp16" if fp16 else None + torch_dtype = torch.float16 if fp16 else torch.float32 + + model = UNet2DConditionModel.from_pretrained( + model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision + ) + model.to(torch_device).eval() + + return model + + def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): + dtype = torch.float16 if fp16 else torch.float32 + hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype) + return hidden_states + + @parameterized.expand( + [ + # fmt: off + [33, 4, [-0.4424, 0.1510, -0.1937, 0.2118, 0.3746, -0.3957, 0.0160, -0.0435]], + [47, 0.55, [-0.1508, 0.0379, -0.3075, 0.2540, 0.3633, -0.0821, 0.1719, -0.0207]], + [21, 0.89, [-0.6479, 0.6364, -0.3464, 0.8697, 0.4443, -0.6289, -0.0091, 0.1778]], + [9, 1000, [0.8888, -0.5659, 0.5834, -0.7469, 1.1912, -0.3923, 1.1241, -0.4424]], + # fmt: on + ] + ) + @require_torch_gpu + def test_compvis_sd_v1_4(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4") + latents = self.get_latents(seed) + encoder_hidden_states = self.get_encoder_hidden_states(seed) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == latents.shape + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) + + @parameterized.expand( + [ + # fmt: off + [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]], + [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]], + [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]], + [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]], + # fmt: on + ] + ) + @require_torch_gpu + def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True) + latents = self.get_latents(seed, fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == latents.shape + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) + + @parameterized.expand( + [ + # fmt: off + [33, 4, [-0.4430, 0.1570, -0.1867, 0.2376, 0.3205, -0.3681, 0.0525, -0.0722]], + [47, 0.55, [-0.1415, 0.0129, -0.3136, 0.2257, 0.3430, -0.0536, 0.2114, -0.0436]], + [21, 0.89, [-0.7091, 0.6664, -0.3643, 0.9032, 0.4499, -0.6541, 0.0139, 0.1750]], + [9, 1000, [0.8878, -0.5659, 0.5844, -0.7442, 1.1883, -0.3927, 1.1192, -0.4423]], + # fmt: on + ] + ) + @require_torch_gpu + def test_compvis_sd_v1_5(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5") + latents = self.get_latents(seed) + encoder_hidden_states = self.get_encoder_hidden_states(seed) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == latents.shape + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) + + @parameterized.expand( + [ + # fmt: off + [83, 4, [-0.2695, -0.1669, 0.0073, -0.3181, -0.1187, -0.1676, -0.1395, -0.5972]], + [17, 0.55, [-0.1290, -0.2588, 0.0551, -0.0916, 0.3286, 0.0238, -0.3669, 0.0322]], + [8, 0.89, [-0.5283, 0.1198, 0.0870, -0.1141, 0.9189, -0.0150, 0.5474, 0.4319]], + [3, 1000, [-0.5601, 0.2411, -0.5435, 0.1268, 1.1338, -0.2427, -0.0280, -1.0020]], + # fmt: on + ] + ) + @require_torch_gpu + def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True) + latents = self.get_latents(seed, fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == latents.shape + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) + + @parameterized.expand( + [ + # fmt: off + [33, 4, [-0.7639, 0.0106, -0.1615, -0.3487, -0.0423, -0.7972, 0.0085, -0.4858]], + [47, 0.55, [-0.6564, 0.0795, -1.9026, -0.6258, 1.8235, 1.2056, 1.2169, 0.9073]], + [21, 0.89, [0.0327, 0.4399, -0.6358, 0.3417, 0.4120, -0.5621, -0.0397, -1.0430]], + [9, 1000, [0.1600, 0.7303, -1.0556, -0.3515, -0.7440, -1.2037, -1.8149, -1.8931]], + # fmt: on + ] + ) + @require_torch_gpu + def test_compvis_sd_inpaint(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting") + latents = self.get_latents(seed, shape=(4, 9, 64, 64)) + encoder_hidden_states = self.get_encoder_hidden_states(seed) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == (4, 4, 64, 64) + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) + + @parameterized.expand( + [ + # fmt: off + [83, 4, [-0.1047, -1.7227, 0.1067, 0.0164, -0.5698, -0.4172, -0.1388, 1.1387]], + [17, 0.55, [0.0975, -0.2856, -0.3508, -0.4600, 0.3376, 0.2930, -0.2747, -0.7026]], + [8, 0.89, [-0.0952, 0.0183, -0.5825, -0.1981, 0.1131, 0.4668, -0.0395, -0.3486]], + [3, 1000, [0.4790, 0.4949, -1.0732, -0.7158, 0.7959, -0.9478, 0.1105, -0.9741]], + # fmt: on + ] + ) + @require_torch_gpu + def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True) + latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == (4, 4, 64, 64) + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py new file mode 100644 index 000000000000..169365756103 --- /dev/null +++ b/tests/models/test_models_vae.py @@ -0,0 +1,306 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import torch + +from diffusers import AutoencoderKL +from diffusers.modeling_utils import ModelMixin +from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device +from parameterized import parameterized + +from ..test_modeling_common import ModelTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): + model_class = AutoencoderKL + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + + return {"sample": image} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": [32, 64], + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "latent_channels": 4, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_forward_signature(self): + pass + + def test_training(self): + pass + + def test_from_pretrained_hub(self): + model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy") + model = model.to(torch_device) + model.eval() + + # One-time warmup pass (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + image = image.to(torch_device) + with torch.no_grad(): + _ = model(image, sample_posterior=True).sample + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + image = torch.randn( + 1, + model.config.in_channels, + model.config.sample_size, + model.config.sample_size, + generator=torch.manual_seed(0), + ) + image = image.to(torch_device) + with torch.no_grad(): + output = model(image, sample_posterior=True, generator=generator).sample + + output_slice = output[0, -1, -3:, -3:].flatten().cpu() + + # Since the VAE Gaussian prior's generator is seeded on the appropriate device, + # the expected output slices are not the same for CPU and GPU. + if torch_device == "mps": + expected_output_slice = torch.tensor( + [ + -4.0078e-01, + -3.8323e-04, + -1.2681e-01, + -1.1462e-01, + 2.0095e-01, + 1.0893e-01, + -8.8247e-02, + -3.0361e-01, + -9.8644e-03, + ] + ) + elif torch_device == "cpu": + expected_output_slice = torch.tensor( + [-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026] + ) + else: + expected_output_slice = torch.tensor( + [-0.2421, 0.4642, 0.2507, -0.0438, 0.0682, 0.3160, -0.2018, -0.0727, 0.2485] + ) + + self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) + + +@slow +class AutoencoderKLIntegrationTests(unittest.TestCase): + def get_file_format(self, seed, shape): + return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" + + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False): + dtype = torch.float16 if fp16 else torch.float32 + image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype) + return image + + def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False): + revision = "fp16" if fp16 else None + torch_dtype = torch.float16 if fp16 else torch.float32 + + model = AutoencoderKL.from_pretrained( + model_id, + subfolder="vae", + torch_dtype=torch_dtype, + revision=revision, + ) + model.to(torch_device).eval() + + return model + + def get_generator(self, seed=0): + return torch.Generator(device=torch_device).manual_seed(seed) + + @parameterized.expand( + [ + # fmt: off + [33, [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824]], + [47, [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089]], + # fmt: on + ] + ) + def test_stable_diffusion(self, seed, expected_slice): + model = self.get_sd_vae_model() + image = self.get_sd_image(seed) + generator = self.get_generator(seed) + + with torch.no_grad(): + sample = model(image, generator=generator, sample_posterior=True).sample + + assert sample.shape == image.shape + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) + + @parameterized.expand( + [ + # fmt: off + [33, [-0.0513, 0.0289, 1.3799, 0.2166, -0.2573, -0.0871, 0.5103, -0.0999]], + [47, [-0.4128, -0.1320, -0.3704, 0.1965, -0.4116, -0.2332, -0.3340, 0.2247]], + # fmt: on + ] + ) + @require_torch_gpu + def test_stable_diffusion_fp16(self, seed, expected_slice): + model = self.get_sd_vae_model(fp16=True) + image = self.get_sd_image(seed, fp16=True) + generator = self.get_generator(seed) + + with torch.no_grad(): + sample = model(image, generator=generator, sample_posterior=True).sample + + assert sample.shape == image.shape + + output_slice = sample[-1, -2:, :2, -2:].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=1e-2) + + @parameterized.expand( + [ + # fmt: off + [33, [-0.1609, 0.9866, -0.0487, -0.0777, -0.2716, 0.8368, -0.2055, -0.0814]], + [47, [-0.2377, 0.1147, 0.1333, -0.4841, -0.2506, -0.0805, -0.0491, -0.4085]], + # fmt: on + ] + ) + def test_stable_diffusion_mode(self, seed, expected_slice): + model = self.get_sd_vae_model() + image = self.get_sd_image(seed) + + with torch.no_grad(): + sample = model(image).sample + + assert sample.shape == image.shape + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) + + @parameterized.expand( + [ + # fmt: off + [13, [-0.2051, -0.1803, -0.2311, -0.2114, -0.3292, -0.3574, -0.2953, -0.3323]], + [37, [-0.2632, -0.2625, -0.2199, -0.2741, -0.4539, -0.4990, -0.3720, -0.4925]], + # fmt: on + ] + ) + @require_torch_gpu + def test_stable_diffusion_decode(self, seed, expected_slice): + model = self.get_sd_vae_model() + encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64)) + + with torch.no_grad(): + sample = model.decode(encoding).sample + + assert list(sample.shape) == [3, 3, 512, 512] + + output_slice = sample[-1, -2:, :2, -2:].flatten().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) + + @parameterized.expand( + [ + # fmt: off + [27, [-0.0369, 0.0207, -0.0776, -0.0682, -0.1747, -0.1930, -0.1465, -0.2039]], + [16, [-0.1628, -0.2134, -0.2747, -0.2642, -0.3774, -0.4404, -0.3687, -0.4277]], + # fmt: on + ] + ) + def test_stable_diffusion_decode_fp16(self, seed, expected_slice): + model = self.get_sd_vae_model(fp16=True) + encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True) + + with torch.no_grad(): + sample = model.decode(encoding).sample + + assert list(sample.shape) == [3, 3, 512, 512] + + output_slice = sample[-1, -2:, :2, -2:].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) + + @parameterized.expand( + [ + # fmt: off + [33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]], + [47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]], + # fmt: on + ] + ) + def test_stable_diffusion_encode_sample(self, seed, expected_slice): + model = self.get_sd_vae_model() + image = self.get_sd_image(seed) + generator = self.get_generator(seed) + + with torch.no_grad(): + dist = model.encode(image).latent_dist + sample = dist.sample(generator=generator) + + assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]] + + output_slice = sample[0, -1, -3:, -3:].flatten().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) diff --git a/tests/test_models_vae_flax.py b/tests/models/test_models_vae_flax.py similarity index 94% rename from tests/test_models_vae_flax.py rename to tests/models/test_models_vae_flax.py index e5c56b61a5a4..8fedb85eccfc 100644 --- a/tests/test_models_vae_flax.py +++ b/tests/models/test_models_vae_flax.py @@ -4,7 +4,7 @@ from diffusers.utils import is_flax_available from diffusers.utils.testing_utils import require_flax -from .test_modeling_common_flax import FlaxModelTesterMixin +from ..test_modeling_common_flax import FlaxModelTesterMixin if is_flax_available(): diff --git a/tests/test_models_vq.py b/tests/models/test_models_vq.py similarity index 98% rename from tests/test_models_vq.py rename to tests/models/test_models_vq.py index 9a2094d46cb4..f58e90469885 100644 --- a/tests/test_models_vq.py +++ b/tests/models/test_models_vq.py @@ -20,7 +20,7 @@ from diffusers import VQModel from diffusers.utils import floats_tensor, torch_device -from .test_modeling_common import ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin torch.backends.cuda.matmul.allow_tf32 = False diff --git a/tests/pipelines/__init__.py b/tests/pipelines/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/altdiffusion/__init__.py b/tests/pipelines/altdiffusion/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion.py b/tests/pipelines/altdiffusion/test_alt_diffusion.py new file mode 100644 index 000000000000..91fe76444920 --- /dev/null +++ b/tests/pipelines/altdiffusion/test_alt_diffusion.py @@ -0,0 +1,347 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import AltDiffusionPipeline, AutoencoderKL, DDIMScheduler, PNDMScheduler, UNet2DConditionModel +from diffusers.pipelines.alt_diffusion.modeling_roberta_series import ( + RobertaSeriesConfig, + RobertaSeriesModelWithTransformation, +) +from diffusers.utils import floats_tensor, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import XLMRobertaTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_cond_unet_inpaint(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = RobertaSeriesConfig( + hidden_size=32, + project_dim=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + vocab_size=5002, + ) + return RobertaSeriesModelWithTransformation(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_alt_diffusion_ddim(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta") + tokenizer.model_max_length = 77 + + # make sure here that pndm scheduler skips prk + alt_pipe = AltDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + alt_pipe = alt_pipe.to(device) + alt_pipe.set_progress_bar_config(disable=None) + + prompt = "A photo of an astronaut" + + generator = torch.Generator(device=device).manual_seed(0) + output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = alt_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [0.5748162, 0.60447145, 0.48821217, 0.50100636, 0.5431185, 0.45763683, 0.49657696, 0.48132733, 0.47573093] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_alt_diffusion_pndm(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta") + tokenizer.model_max_length = 77 + + # make sure here that pndm scheduler skips prk + alt_pipe = AltDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + alt_pipe = alt_pipe.to(device) + alt_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = alt_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_alt_diffusion_fp16(self): + """Test that stable diffusion works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta") + tokenizer.model_max_length = 77 + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + alt_pipe = AltDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + alt_pipe = alt_pipe.to(torch_device) + alt_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images + + assert image.shape == (1, 64, 64, 3) + + +@slow +@require_torch_gpu +class AltDiffusionPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_alt_diffusion(self): + # make sure here that pndm scheduler skips prk + alt_pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", safety_checker=None) + alt_pipe = alt_pipe.to(torch_device) + alt_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast("cuda"): + output = alt_pipe( + [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np" + ) + + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array( + [0.8720703, 0.87109375, 0.87402344, 0.87109375, 0.8779297, 0.8925781, 0.8823242, 0.8808594, 0.8613281] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_alt_diffusion_fast_ddim(self): + scheduler = DDIMScheduler.from_pretrained("BAAI/AltDiffusion", subfolder="scheduler") + + alt_pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", scheduler=scheduler, safety_checker=None) + alt_pipe = alt_pipe.to(torch_device) + alt_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + + with torch.autocast("cuda"): + output = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy") + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array( + [0.9267578, 0.9301758, 0.9013672, 0.9345703, 0.92578125, 0.94433594, 0.9423828, 0.9423828, 0.9160156] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_alt_diffusion_text2img_pipeline_fp16(self): + torch.cuda.reset_peak_memory_stats() + model_id = "BAAI/AltDiffusion" + pipe = AltDiffusionPipeline.from_pretrained( + model_id, revision="fp16", torch_dtype=torch.float16, safety_checker=None + ) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "a photograph of an astronaut riding a horse" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output_chunked = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image_chunked = output_chunked.images + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image = output.images + + # Make sure results are close enough + diff = np.abs(image_chunked.flatten() - image.flatten()) + # They ARE different since ops are not run always at the same precision + # however, they should be extremely close. + assert diff.mean() < 2e-2 diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py new file mode 100644 index 000000000000..0dab14b31716 --- /dev/null +++ b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py @@ -0,0 +1,256 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import AltDiffusionImg2ImgPipeline, AutoencoderKL, PNDMScheduler, UNet2DConditionModel +from diffusers.pipelines.alt_diffusion.modeling_roberta_series import ( + RobertaSeriesConfig, + RobertaSeriesModelWithTransformation, +) +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import XLMRobertaTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class AltDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = RobertaSeriesConfig( + hidden_size=32, + project_dim=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=5006, + ) + return RobertaSeriesModelWithTransformation(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_stable_diffusion_img2img_default_case(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta") + tokenizer.model_max_length = 77 + + init_image = self.dummy_image.to(device) + + # make sure here that pndm scheduler skips prk + alt_pipe = AltDiffusionImg2ImgPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + alt_pipe = alt_pipe.to(device) + alt_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = alt_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ) + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = alt_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array( + [0.41293705, 0.38656747, 0.40876025, 0.4782187, 0.4656803, 0.41394007, 0.4142093, 0.47150758, 0.4570448] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1.5e-3 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1.5e-3 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_img2img_fp16(self): + """Test that stable diffusion img2img works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta") + tokenizer.model_max_length = 77 + + init_image = self.dummy_image.to(torch_device) + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + alt_pipe = AltDiffusionImg2ImgPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + alt_pipe = alt_pipe.to(torch_device) + alt_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = alt_pipe( + [prompt], + generator=generator, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ).images + + assert image.shape == (1, 32, 32, 3) + + +@slow +@require_torch_gpu +class AltDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_img2img_pipeline_default(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((768, 512)) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape_alt.npy" + ) + + model_id = "BAAI/AltDiffusion" + pipe = AltDiffusionImg2ImgPipeline.from_pretrained( + model_id, + safety_checker=None, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "A fantasy landscape, trending on artstation" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 768, 3) + # img2img is flaky across GPUs even in fp32, so using MAE here + assert np.abs(expected_image - image).max() < 1e-3 diff --git a/tests/pipelines/dance_diffusion/test_dance_diffusion.py b/tests/pipelines/dance_diffusion/test_dance_diffusion.py new file mode 100644 index 000000000000..a63ef84c63f5 --- /dev/null +++ b/tests/pipelines/dance_diffusion/test_dance_diffusion.py @@ -0,0 +1,122 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch + +from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel +from diffusers.utils import slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class PipelineFastTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_unet(self): + torch.manual_seed(0) + model = UNet1DModel( + block_out_channels=(32, 32, 64), + extra_in_channels=16, + sample_size=512, + sample_rate=16_000, + in_channels=2, + out_channels=2, + flip_sin_to_cos=True, + use_timestep_embedding=False, + time_embedding_type="fourier", + mid_block_type="UNetMidBlock1D", + down_block_types=["DownBlock1DNoSkip"] + ["DownBlock1D"] + ["AttnDownBlock1D"], + up_block_types=["AttnUpBlock1D"] + ["UpBlock1D"] + ["UpBlock1DNoSkip"], + ) + return model + + def test_dance_diffusion(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + scheduler = IPNDMScheduler() + + pipe = DanceDiffusionPipeline(unet=self.dummy_unet, scheduler=scheduler) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe(generator=generator, num_inference_steps=4) + audio = output.audios + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe(generator=generator, num_inference_steps=4, return_dict=False) + audio_from_tuple = output[0] + + audio_slice = audio[0, -3:, -3:] + audio_from_tuple_slice = audio_from_tuple[0, -3:, -3:] + + assert audio.shape == (1, 2, self.dummy_unet.sample_size) + expected_slice = np.array([-0.7265, 1.0000, -0.8388, 0.1175, 0.9498, -1.0000]) + assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(audio_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch_gpu +class PipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_dance_diffusion(self): + device = torch_device + + pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k") + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096) + audio = output.audios + + audio_slice = audio[0, -3:, -3:] + + assert audio.shape == (1, 2, pipe.unet.sample_size) + expected_slice = np.array([-0.1576, -0.1526, -0.127, -0.2699, -0.2762, -0.2487]) + assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2 + + def test_dance_diffusion_fp16(self): + device = torch_device + + pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096) + audio = output.audios + + audio_slice = audio[0, -3:, -3:] + + assert audio.shape == (1, 2, pipe.unet.sample_size) + expected_slice = np.array([-0.1693, -0.1698, -0.1447, -0.3044, -0.3203, -0.2937]) + assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/ddim/__init__.py b/tests/pipelines/ddim/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py new file mode 100644 index 000000000000..2d03383599e0 --- /dev/null +++ b/tests/pipelines/ddim/test_ddim.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel +from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + @property + def dummy_uncond_unet(self): + torch.manual_seed(0) + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + return model + + def test_inference(self): + device = "cpu" + unet = self.dummy_uncond_unet + scheduler = DDIMScheduler() + + ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) + ddpm.to(device) + ddpm.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(0) + image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array( + [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch_gpu +class DDIMPipelineIntegrationTests(unittest.TestCase): + def test_inference_ema_bedroom(self): + model_id = "google/ddpm-ema-bedroom-256" + + unet = UNet2DModel.from_pretrained(model_id) + scheduler = DDIMScheduler.from_pretrained(model_id) + + ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) + ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ddpm(generator=generator, output_type="numpy").images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.1546, 0.1561, 0.1595, 0.1564, 0.1569, 0.1585, 0.1554, 0.1550, 0.1575]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_inference_cifar10(self): + model_id = "google/ddpm-cifar10-32" + + unet = UNet2DModel.from_pretrained(model_id) + scheduler = DDIMScheduler() + + ddim = DDIMPipeline(unet=unet, scheduler=scheduler) + ddim.to(torch_device) + ddim.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ddim(generator=generator, eta=0.0, output_type="numpy").images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.2060, 0.2042, 0.2022, 0.2193, 0.2146, 0.2110, 0.2471, 0.2446, 0.2388]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/ddpm/__init__.py b/tests/pipelines/ddpm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py new file mode 100644 index 000000000000..6656fb738d51 --- /dev/null +++ b/tests/pipelines/ddpm/test_ddpm.py @@ -0,0 +1,151 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel +from diffusers.utils import deprecate +from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + @property + def dummy_uncond_unet(self): + torch.manual_seed(0) + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + return model + + def test_inference(self): + device = "cpu" + unet = self.dummy_uncond_unet + scheduler = DDPMScheduler() + + ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) + ddpm.to(device) + ddpm.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(0) + image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array( + [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_inference_deprecated_predict_epsilon(self): + deprecate("remove this test", "0.10.0", "remove") + unet = self.dummy_uncond_unet + scheduler = DDPMScheduler(predict_epsilon=False) + + ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) + ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = ddpm(num_inference_steps=1) + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images + + generator = generator.manual_seed(0) + image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0] + + image_slice = image[0, -3:, -3:, -1] + image_eps_slice = image_eps[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + tolerance = 1e-2 if torch_device != "mps" else 3e-2 + assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance + + def test_inference_predict_sample(self): + unet = self.dummy_uncond_unet + scheduler = DDPMScheduler(prediction_type="sample") + + ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) + ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = ddpm(num_inference_steps=1) + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images + + generator = generator.manual_seed(0) + image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")[0] + + image_slice = image[0, -3:, -3:, -1] + image_eps_slice = image_eps[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + tolerance = 1e-2 if torch_device != "mps" else 3e-2 + assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance + + +@slow +@require_torch_gpu +class DDPMPipelineIntegrationTests(unittest.TestCase): + def test_inference_cifar10(self): + model_id = "google/ddpm-cifar10-32" + + unet = UNet2DModel.from_pretrained(model_id) + scheduler = DDPMScheduler.from_pretrained(model_id) + + ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) + ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ddpm(generator=generator, output_type="numpy").images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/karras_ve/__init__.py b/tests/pipelines/karras_ve/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/karras_ve/test_karras_ve.py b/tests/pipelines/karras_ve/test_karras_ve.py new file mode 100644 index 000000000000..1fafa1cb40fa --- /dev/null +++ b/tests/pipelines/karras_ve/test_karras_ve.py @@ -0,0 +1,86 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import KarrasVePipeline, KarrasVeScheduler, UNet2DModel +from diffusers.utils.testing_utils import require_torch, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class KarrasVePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + @property + def dummy_uncond_unet(self): + torch.manual_seed(0) + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + return model + + def test_inference(self): + unet = self.dummy_uncond_unet + scheduler = KarrasVeScheduler() + + pipe = KarrasVePipeline(unet=unet, scheduler=scheduler) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.manual_seed(0) + image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images + + generator = torch.manual_seed(0) + image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch +class KarrasVePipelineIntegrationTests(unittest.TestCase): + def test_inference(self): + model_id = "google/ncsnpp-celebahq-256" + model = UNet2DModel.from_pretrained(model_id) + scheduler = KarrasVeScheduler() + + pipe = KarrasVePipeline(unet=model, scheduler=scheduler) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.manual_seed(0) + image = pipe(num_inference_steps=20, generator=generator, output_type="numpy").images + + image_slice = image[0, -3:, -3:, -1] + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/latent_diffusion/__init__.py b/tests/pipelines/latent_diffusion/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion.py b/tests/pipelines/latent_diffusion/test_latent_diffusion.py new file mode 100644 index 000000000000..9d5c07809dc0 --- /dev/null +++ b/tests/pipelines/latent_diffusion/test_latent_diffusion.py @@ -0,0 +1,153 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel +from diffusers.utils.testing_utils import require_torch, slow, torch_device +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config) + + def test_inference_text2img(self): + unet = self.dummy_cond_unet + scheduler = DDIMScheduler() + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + ldm.to(torch_device) + ldm.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + generator = torch.manual_seed(0) + _ = ldm( + [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy" + ).images + + generator = torch.manual_seed(0) + image = ldm( + [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy" + ).images + + generator = torch.manual_seed(0) + image_from_tuple = ldm( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="numpy", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 16, 16, 3) + expected_slice = np.array([0.6806, 0.5454, 0.5638, 0.4893, 0.4656, 0.4257, 0.6248, 0.5217, 0.5498]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch +class LDMTextToImagePipelineIntegrationTests(unittest.TestCase): + def test_inference_text2img(self): + ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256") + ldm.to(torch_device) + ldm.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.manual_seed(0) + image = ldm( + [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy" + ).images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_inference_text2img_fast(self): + ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256") + ldm.to(torch_device) + ldm.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.manual_seed(0) + image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py new file mode 100644 index 000000000000..6f1f51c7ba75 --- /dev/null +++ b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py @@ -0,0 +1,133 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import numpy as np +import torch + +from diffusers import DDIMScheduler, LDMSuperResolutionPipeline, UNet2DModel, VQModel +from diffusers.utils import PIL_INTERPOLATION, floats_tensor, load_image, slow, torch_device +from diffusers.utils.testing_utils import require_torch + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_uncond_unet(self): + torch.manual_seed(0) + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=6, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + return model + + @property + def dummy_vq_model(self): + torch.manual_seed(0) + model = VQModel( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=3, + ) + return model + + def test_inference_superresolution(self): + device = "cpu" + unet = self.dummy_uncond_unet + scheduler = DDIMScheduler() + vqvae = self.dummy_vq_model + + ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler) + ldm.to(device) + ldm.set_progress_bar_config(disable=None) + + init_image = self.dummy_image.to(device) + + generator = torch.Generator(device=device).manual_seed(0) + image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.8678, 0.8245, 0.6381, 0.6830, 0.4385, 0.5599, 0.4641, 0.6201, 0.5150]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_inference_superresolution_fp16(self): + unet = self.dummy_uncond_unet + scheduler = DDIMScheduler() + vqvae = self.dummy_vq_model + + # put models in fp16 + unet = unet.half() + vqvae = vqvae.half() + + ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler) + ldm.to(torch_device) + ldm.set_progress_bar_config(disable=None) + + init_image = self.dummy_image.to(torch_device) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images + + assert image.shape == (1, 64, 64, 3) + + +@slow +@require_torch +class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase): + def test_inference_superresolution(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/vq_diffusion/teddy_bear_pool.png" + ) + init_image = init_image.resize((64, 64), resample=PIL_INTERPOLATION["lanczos"]) + + ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto") + ldm.to(torch_device) + ldm.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ldm(init_image, generator=generator, num_inference_steps=20, output_type="numpy").images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.7418, 0.7472, 0.7424, 0.7422, 0.7463, 0.726, 0.7382, 0.7248, 0.6828]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion_uncond.py b/tests/pipelines/latent_diffusion/test_latent_diffusion_uncond.py new file mode 100644 index 000000000000..dea2971cbb95 --- /dev/null +++ b/tests/pipelines/latent_diffusion/test_latent_diffusion_uncond.py @@ -0,0 +1,119 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import DDIMScheduler, LDMPipeline, UNet2DModel, VQModel +from diffusers.utils.testing_utils import require_torch, slow, torch_device +from transformers import CLIPTextConfig, CLIPTextModel + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class LDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + @property + def dummy_uncond_unet(self): + torch.manual_seed(0) + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + return model + + @property + def dummy_vq_model(self): + torch.manual_seed(0) + model = VQModel( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=3, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config) + + def test_inference_uncond(self): + unet = self.dummy_uncond_unet + scheduler = DDIMScheduler() + vae = self.dummy_vq_model + + ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler) + ldm.to(torch_device) + ldm.set_progress_bar_config(disable=None) + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + generator = torch.manual_seed(0) + _ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images + + generator = torch.manual_seed(0) + image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images + + generator = torch.manual_seed(0) + image_from_tuple = ldm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch +class LDMPipelineIntegrationTests(unittest.TestCase): + def test_inference_uncond(self): + ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256") + ldm.to(torch_device) + ldm.set_progress_bar_config(disable=None) + + generator = torch.manual_seed(0) + image = ldm(generator=generator, num_inference_steps=5, output_type="numpy").images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/pndm/__init__.py b/tests/pipelines/pndm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/pndm/test_pndm.py b/tests/pipelines/pndm/test_pndm.py new file mode 100644 index 000000000000..5d9212223e6e --- /dev/null +++ b/tests/pipelines/pndm/test_pndm.py @@ -0,0 +1,87 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import PNDMPipeline, PNDMScheduler, UNet2DModel +from diffusers.utils.testing_utils import require_torch, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class PNDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + @property + def dummy_uncond_unet(self): + torch.manual_seed(0) + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + return model + + def test_inference(self): + unet = self.dummy_uncond_unet + scheduler = PNDMScheduler() + + pndm = PNDMPipeline(unet=unet, scheduler=scheduler) + pndm.to(torch_device) + pndm.set_progress_bar_config(disable=None) + + generator = torch.manual_seed(0) + image = pndm(generator=generator, num_inference_steps=20, output_type="numpy").images + + generator = torch.manual_seed(0) + image_from_tuple = pndm(generator=generator, num_inference_steps=20, output_type="numpy", return_dict=False)[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch +class PNDMPipelineIntegrationTests(unittest.TestCase): + def test_inference_cifar10(self): + model_id = "google/ddpm-cifar10-32" + + unet = UNet2DModel.from_pretrained(model_id) + scheduler = PNDMScheduler() + + pndm = PNDMPipeline(unet=unet, scheduler=scheduler) + pndm.to(torch_device) + pndm.set_progress_bar_config(disable=None) + generator = torch.manual_seed(0) + image = pndm(generator=generator, output_type="numpy").images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/repaint/__init__.py b/tests/pipelines/repaint/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/repaint/test_repaint.py b/tests/pipelines/repaint/test_repaint.py new file mode 100644 index 000000000000..3ab0efc875f1 --- /dev/null +++ b/tests/pipelines/repaint/test_repaint.py @@ -0,0 +1,65 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel +from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device + + +torch.backends.cuda.matmul.allow_tf32 = False + + +@slow +@require_torch_gpu +class RepaintPipelineIntegrationTests(unittest.TestCase): + def test_celebahq(self): + original_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/" + "repaint/celeba_hq_256.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png" + ) + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/" + "repaint/celeba_hq_256_result.png" + ) + expected_image = np.array(expected_image, dtype=np.float32) / 255.0 + + model_id = "google/ddpm-ema-celebahq-256" + unet = UNet2DModel.from_pretrained(model_id) + scheduler = RePaintScheduler.from_pretrained(model_id) + + repaint = RePaintPipeline(unet=unet, scheduler=scheduler).to(torch_device) + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = repaint( + original_image, + mask_image, + num_inference_steps=250, + eta=0.0, + jump_length=10, + jump_n_sample=10, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (256, 256, 3) + assert np.abs(expected_image - image).mean() < 1e-2 diff --git a/tests/pipelines/score_sde_ve/__init__.py b/tests/pipelines/score_sde_ve/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/score_sde_ve/test_score_sde_ve.py b/tests/pipelines/score_sde_ve/test_score_sde_ve.py new file mode 100644 index 000000000000..9cdf3f0191e1 --- /dev/null +++ b/tests/pipelines/score_sde_ve/test_score_sde_ve.py @@ -0,0 +1,91 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel +from diffusers.utils.testing_utils import require_torch, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class ScoreSdeVeipelineFastTests(PipelineTesterMixin, unittest.TestCase): + @property + def dummy_uncond_unet(self): + torch.manual_seed(0) + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + return model + + def test_inference(self): + unet = self.dummy_uncond_unet + scheduler = ScoreSdeVeScheduler() + + sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler) + sde_ve.to(torch_device) + sde_ve.set_progress_bar_config(disable=None) + + generator = torch.manual_seed(0) + image = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator).images + + generator = torch.manual_seed(0) + image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator, return_dict=False)[ + 0 + ] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch +class ScoreSdeVePipelineIntegrationTests(unittest.TestCase): + def test_inference(self): + model_id = "google/ncsnpp-church-256" + model = UNet2DModel.from_pretrained(model_id) + + scheduler = ScoreSdeVeScheduler.from_pretrained(model_id) + + sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler) + sde_ve.to(torch_device) + sde_ve.set_progress_bar_config(disable=None) + + generator = torch.manual_seed(0) + image = sde_ve(num_inference_steps=10, output_type="numpy", generator=generator).images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 256, 256, 3) + + expected_slice = np.array([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion/__init__.py b/tests/pipelines/stable_diffusion/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/stable_diffusion/test_cycle_diffusion.py b/tests/pipelines/stable_diffusion/test_cycle_diffusion.py new file mode 100644 index 000000000000..7a32b74096c4 --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_cycle_diffusion.py @@ -0,0 +1,350 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import AutoencoderKL, CycleDiffusionPipeline, DDIMScheduler, UNet2DConditionModel, UNet2DModel, VQModel +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_uncond_unet(self): + torch.manual_seed(0) + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + return model + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_cond_unet_inpaint(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vq_model(self): + torch.manual_seed(0) + model = VQModel( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=3, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_stable_diffusion_cycle(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + set_alpha_to_one=False, + ) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = CycleDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + source_prompt = "An astronaut riding a horse" + prompt = "An astronaut riding an elephant" + init_image = self.dummy_image.to(device) + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + prompt=prompt, + source_prompt=source_prompt, + generator=generator, + num_inference_steps=2, + init_image=init_image, + eta=0.1, + strength=0.8, + guidance_scale=3, + source_guidance_scale=1, + output_type="np", + ) + images = output.images + + image_slice = images[0, -3:, -3:, -1] + + assert images.shape == (1, 32, 32, 3) + expected_slice = np.array([0.4459, 0.4943, 0.4544, 0.6643, 0.5474, 0.4327, 0.5701, 0.5959, 0.5179]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_cycle_fp16(self): + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + set_alpha_to_one=False, + ) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = CycleDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + source_prompt = "An astronaut riding a horse" + prompt = "An astronaut riding an elephant" + init_image = self.dummy_image.to(torch_device) + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = sd_pipe( + prompt=prompt, + source_prompt=source_prompt, + generator=generator, + num_inference_steps=2, + init_image=init_image, + eta=0.1, + strength=0.8, + guidance_scale=3, + source_guidance_scale=1, + output_type="np", + ) + images = output.images + + image_slice = images[0, -3:, -3:, -1] + + assert images.shape == (1, 32, 32, 3) + expected_slice = np.array([0.3506, 0.4543, 0.446, 0.4575, 0.5195, 0.4155, 0.5273, 0.518, 0.4116]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch_gpu +class CycleDiffusionPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_cycle_diffusion_pipeline_fp16(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/cycle-diffusion/black_colored_car.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/cycle-diffusion/blue_colored_car_fp16.npy" + ) + init_image = init_image.resize((512, 512)) + + model_id = "CompVis/stable-diffusion-v1-4" + scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = CycleDiffusionPipeline.from_pretrained( + model_id, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16, revision="fp16" + ) + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + source_prompt = "A black colored car" + prompt = "A blue colored car" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + source_prompt=source_prompt, + init_image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.85, + guidance_scale=3, + source_guidance_scale=1, + generator=generator, + output_type="np", + ) + image = output.images + + # the values aren't exactly equal, but the images look the same visually + assert np.abs(image - expected_image).max() < 5e-1 + + def test_cycle_diffusion_pipeline(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/cycle-diffusion/black_colored_car.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/cycle-diffusion/blue_colored_car.npy" + ) + init_image = init_image.resize((512, 512)) + + model_id = "CompVis/stable-diffusion-v1-4" + scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None) + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + source_prompt = "A black colored car" + prompt = "A blue colored car" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + source_prompt=source_prompt, + init_image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.85, + guidance_scale=3, + source_guidance_scale=1, + generator=generator, + output_type="np", + ) + image = output.images + + assert np.abs(image - expected_image).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py new file mode 100644 index 000000000000..a2b48d27e6e0 --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py @@ -0,0 +1,191 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import numpy as np + +from diffusers import DDIMScheduler, LMSDiscreteScheduler, OnnxStableDiffusionPipeline +from diffusers.utils.testing_utils import is_onnx_available, require_onnxruntime, require_torch_gpu, slow + +from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin + + +if is_onnx_available(): + import onnxruntime as ort + + +class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase): + # FIXME: add fast tests + pass + + +@slow +@require_onnxruntime +@require_torch_gpu +class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase): + @property + def gpu_provider(self): + return ( + "CUDAExecutionProvider", + { + "gpu_mem_limit": "15000000000", # 15GB + "arena_extend_strategy": "kSameAsRequested", + }, + ) + + @property + def gpu_options(self): + options = ort.SessionOptions() + options.enable_mem_pattern = False + return options + + def test_inference_default_pndm(self): + # using the PNDM scheduler by default + sd_pipe = OnnxStableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="onnx", + provider=self.gpu_provider, + sess_options=self.gpu_options, + ) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + np.random.seed(0) + output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=10, output_type="np") + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0452, 0.0390, 0.0087, 0.0350, 0.0617, 0.0364, 0.0544, 0.0523, 0.0720]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_inference_ddim(self): + ddim_scheduler = DDIMScheduler.from_pretrained( + "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" + ) + sd_pipe = OnnxStableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + revision="onnx", + scheduler=ddim_scheduler, + provider=self.gpu_provider, + sess_options=self.gpu_options, + ) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "open neural network exchange" + generator = np.random.RandomState(0) + output = sd_pipe([prompt], guidance_scale=7.5, num_inference_steps=10, generator=generator, output_type="np") + image = output.images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.2867, 0.1974, 0.1481, 0.7294, 0.7251, 0.6667, 0.4194, 0.5642, 0.6486]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_inference_k_lms(self): + lms_scheduler = LMSDiscreteScheduler.from_pretrained( + "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" + ) + sd_pipe = OnnxStableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + revision="onnx", + scheduler=lms_scheduler, + provider=self.gpu_provider, + sess_options=self.gpu_options, + ) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "open neural network exchange" + generator = np.random.RandomState(0) + output = sd_pipe([prompt], guidance_scale=7.5, num_inference_steps=10, generator=generator, output_type="np") + image = output.images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.2306, 0.1959, 0.1593, 0.6549, 0.6394, 0.5408, 0.5065, 0.6010, 0.6161]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.6772, -0.3835, -1.2456, 0.1905, -1.0974, 0.6967, -1.9353, 0.0178, 1.0167] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + elif step == 5: + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.3351, 0.2241, -0.1837, -0.2325, -0.6577, 0.3393, -0.0241, 0.5899, 1.3875] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + + test_callback_fn.has_been_called = False + + pipe = OnnxStableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + revision="onnx", + provider=self.gpu_provider, + sess_options=self.gpu_options, + ) + pipe.set_progress_bar_config(disable=None) + + prompt = "Andromeda galaxy in a bottle" + + generator = np.random.RandomState(0) + pipe( + prompt=prompt, + num_inference_steps=5, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 6 + + def test_stable_diffusion_no_safety_checker(self): + pipe = OnnxStableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + revision="onnx", + provider=self.gpu_provider, + sess_options=self.gpu_options, + safety_checker=None, + ) + assert isinstance(pipe, OnnxStableDiffusionPipeline) + assert pipe.safety_checker is None + + image = pipe("example prompt", num_inference_steps=2).images[0] + assert image is not None + + # check that there's no error when saving a pipeline with one of the models being None + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = OnnxStableDiffusionPipeline.from_pretrained(tmpdirname) + + # sanity check that the pipeline still works + assert pipe.safety_checker is None + image = pipe("example prompt", num_inference_steps=2).images[0] + assert image is not None diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py new file mode 100644 index 000000000000..91e4412425b4 --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py @@ -0,0 +1,126 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from diffusers import LMSDiscreteScheduler, OnnxStableDiffusionImg2ImgPipeline +from diffusers.utils.testing_utils import is_onnx_available, load_image, require_onnxruntime, require_torch_gpu, slow + +from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin + + +if is_onnx_available(): + import onnxruntime as ort + + +class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase): + # FIXME: add fast tests + pass + + +@slow +@require_onnxruntime +@require_torch_gpu +class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): + @property + def gpu_provider(self): + return ( + "CUDAExecutionProvider", + { + "gpu_mem_limit": "15000000000", # 15GB + "arena_extend_strategy": "kSameAsRequested", + }, + ) + + @property + def gpu_options(self): + options = ort.SessionOptions() + options.enable_mem_pattern = False + return options + + def test_inference_default_pndm(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((768, 512)) + # using the PNDM scheduler by default + pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="onnx", + provider=self.gpu_provider, + sess_options=self.gpu_options, + ) + pipe.set_progress_bar_config(disable=None) + + prompt = "A fantasy landscape, trending on artstation" + + generator = np.random.RandomState(0) + output = pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + guidance_scale=7.5, + num_inference_steps=10, + generator=generator, + output_type="np", + ) + images = output.images + image_slice = images[0, 255:258, 383:386, -1] + + assert images.shape == (1, 512, 768, 3) + expected_slice = np.array([0.4909, 0.5059, 0.5372, 0.4623, 0.4876, 0.5049, 0.4820, 0.4956, 0.5019]) + # TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues + assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2 + + def test_inference_k_lms(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((768, 512)) + lms_scheduler = LMSDiscreteScheduler.from_pretrained( + "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" + ) + pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + revision="onnx", + scheduler=lms_scheduler, + provider=self.gpu_provider, + sess_options=self.gpu_options, + ) + pipe.set_progress_bar_config(disable=None) + + prompt = "A fantasy landscape, trending on artstation" + + generator = np.random.RandomState(0) + output = pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + guidance_scale=7.5, + num_inference_steps=10, + generator=generator, + output_type="np", + ) + images = output.images + image_slice = images[0, 255:258, 383:386, -1] + + assert images.shape == (1, 512, 768, 3) + expected_slice = np.array([0.7950, 0.7923, 0.7903, 0.5516, 0.5501, 0.5476, 0.4965, 0.4933, 0.4910]) + # TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues + assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2 diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py new file mode 100644 index 000000000000..507375bddbfd --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py @@ -0,0 +1,129 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from diffusers import LMSDiscreteScheduler, OnnxStableDiffusionInpaintPipeline +from diffusers.utils.testing_utils import is_onnx_available, load_image, require_onnxruntime, require_torch_gpu, slow + +from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin + + +if is_onnx_available(): + import onnxruntime as ort + + +class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase): + # FIXME: add fast tests + pass + + +@slow +@require_onnxruntime +@require_torch_gpu +class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): + @property + def gpu_provider(self): + return ( + "CUDAExecutionProvider", + { + "gpu_mem_limit": "15000000000", # 15GB + "arena_extend_strategy": "kSameAsRequested", + }, + ) + + @property + def gpu_options(self): + options = ort.SessionOptions() + options.enable_mem_pattern = False + return options + + def test_inference_default_pndm(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", + revision="onnx", + provider=self.gpu_provider, + sess_options=self.gpu_options, + ) + pipe.set_progress_bar_config(disable=None) + + prompt = "A red cat sitting on a park bench" + + generator = np.random.RandomState(0) + output = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + guidance_scale=7.5, + num_inference_steps=10, + generator=generator, + output_type="np", + ) + images = output.images + image_slice = images[0, 255:258, 255:258, -1] + + assert images.shape == (1, 512, 512, 3) + expected_slice = np.array([0.2514, 0.3007, 0.3517, 0.1790, 0.2382, 0.3167, 0.1944, 0.2273, 0.2464]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_inference_k_lms(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + lms_scheduler = LMSDiscreteScheduler.from_pretrained( + "runwayml/stable-diffusion-inpainting", subfolder="scheduler", revision="onnx" + ) + pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", + revision="onnx", + scheduler=lms_scheduler, + provider=self.gpu_provider, + sess_options=self.gpu_options, + ) + pipe.set_progress_bar_config(disable=None) + + prompt = "A red cat sitting on a park bench" + + generator = np.random.RandomState(0) + output = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + guidance_scale=7.5, + num_inference_steps=10, + generator=generator, + output_type="np", + ) + images = output.images + image_slice = images[0, 255:258, 255:258, -1] + + assert images.shape == (1, 512, 512, 3) + expected_slice = np.array([0.2520, 0.2743, 0.2643, 0.2641, 0.2517, 0.2650, 0.2498, 0.2688, 0.2529]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint_legacy.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint_legacy.py new file mode 100644 index 000000000000..577023f7055c --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,95 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from diffusers import OnnxStableDiffusionInpaintPipelineLegacy +from diffusers.utils.testing_utils import ( + is_onnx_available, + load_image, + load_numpy, + require_onnxruntime, + require_torch_gpu, + slow, +) + + +if is_onnx_available(): + import onnxruntime as ort + + +@slow +@require_onnxruntime +@require_torch_gpu +class StableDiffusionOnnxInpaintLegacyPipelineIntegrationTests(unittest.TestCase): + @property + def gpu_provider(self): + return ( + "CUDAExecutionProvider", + { + "gpu_mem_limit": "15000000000", # 15GB + "arena_extend_strategy": "kSameAsRequested", + }, + ) + + @property + def gpu_options(self): + options = ort.SessionOptions() + options.enable_mem_pattern = False + return options + + def test_inference(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/red_cat_sitting_on_a_park_bench_onnx.npy" + ) + + # using the PNDM scheduler by default + pipe = OnnxStableDiffusionInpaintPipelineLegacy.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="onnx", + provider=self.gpu_provider, + sess_options=self.gpu_options, + ) + pipe.set_progress_bar_config(disable=None) + + prompt = "A red cat sitting on a park bench" + + generator = np.random.RandomState(0) + output = pipe( + prompt=prompt, + init_image=init_image, + mask_image=mask_image, + strength=0.75, + guidance_scale=7.5, + num_inference_steps=15, + generator=generator, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py new file mode 100644 index 000000000000..e2e27a211d88 --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -0,0 +1,1021 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import tempfile +import time +import unittest + +import numpy as np +import torch + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, + UNet2DModel, + VQModel, + logging, +) +from diffusers.utils import floats_tensor, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_uncond_unet(self): + torch.manual_seed(0) + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + return model + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_cond_unet_inpaint(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vq_model(self): + torch.manual_seed(0) + model = VQModel( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=3, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_stable_diffusion_ddim(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [ + 0.5643956661224365, + 0.6017904281616211, + 0.4799129366874695, + 0.5267305374145508, + 0.5584856271743774, + 0.46413588523864746, + 0.5159522294998169, + 0.4963662028312683, + 0.47919973731040955, + ] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_ddim_factor_8(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + height=136, + width=136, + num_inference_steps=2, + output_type="np", + ) + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 136, 136, 3) + expected_slice = np.array([0.5524, 0.5626, 0.6069, 0.4727, 0.386, 0.3995, 0.4613, 0.4328, 0.4269]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_pndm(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [ + 0.5094760060310364, + 0.5674174427986145, + 0.46675148606300354, + 0.5125715136528015, + 0.5696930289268494, + 0.4674668312072754, + 0.5277683734893799, + 0.4964486062526703, + 0.494540274143219, + ] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_no_safety_checker(self): + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None + ) + assert isinstance(pipe, StableDiffusionPipeline) + assert isinstance(pipe.scheduler, LMSDiscreteScheduler) + assert pipe.safety_checker is None + + image = pipe("example prompt", num_inference_steps=2).images[0] + assert image is not None + + # check that there's no error when saving a pipeline with one of the models being None + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = StableDiffusionPipeline.from_pretrained(tmpdirname) + + # sanity check that the pipeline still works + assert pipe.safety_checker is None + image = pipe("example prompt", num_inference_steps=2).images[0] + assert image is not None + + def test_stable_diffusion_k_lms(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [ + 0.47082293033599854, + 0.5371589064598083, + 0.4562119245529175, + 0.5220914483070374, + 0.5733777284622192, + 0.4795039892196655, + 0.5465868711471558, + 0.5074326395988464, + 0.5042197108268738, + ] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_k_euler_ancestral(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = EulerAncestralDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [ + 0.4707113206386566, + 0.5372191071510315, + 0.4563021957874298, + 0.5220003724098206, + 0.5734264850616455, + 0.4794946610927582, + 0.5463782548904419, + 0.5074145197868347, + 0.504422664642334, + ] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_k_euler(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [ + 0.47082313895225525, + 0.5371587872505188, + 0.4562119245529175, + 0.5220913887023926, + 0.5733776688575745, + 0.47950395941734314, + 0.546586811542511, + 0.5074326992034912, + 0.5042197108268738, + ] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_attention_chunk(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output_1 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + # make sure chunking the attention yields the same result + sd_pipe.enable_attention_slicing(slice_size=1) + generator = torch.Generator(device=device).manual_seed(0) + output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4 + + def test_stable_diffusion_negative_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + negative_prompt = "french fries" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + prompt, + negative_prompt=negative_prompt, + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [ + 0.5108221173286438, + 0.5688379406929016, + 0.4685141146183014, + 0.5098261833190918, + 0.5657756328582764, + 0.4631010890007019, + 0.5226285457611084, + 0.49129390716552734, + 0.4899061322212219, + ] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_num_images_per_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + # test num_images_per_prompt=1 (default) + images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images + + assert images.shape == (1, 64, 64, 3) + + # test num_images_per_prompt=1 (default) for batch of prompts + batch_size = 2 + images = sd_pipe([prompt] * batch_size, num_inference_steps=2, output_type="np").images + + assert images.shape == (batch_size, 64, 64, 3) + + # test num_images_per_prompt for single prompt + num_images_per_prompt = 2 + images = sd_pipe( + prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt + ).images + + assert images.shape == (num_images_per_prompt, 64, 64, 3) + + # test num_images_per_prompt for batch of prompts + batch_size = 2 + images = sd_pipe( + [prompt] * batch_size, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt + ).images + + assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3) + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_fp16(self): + """Test that stable diffusion works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images + + assert image.shape == (1, 64, 64, 3) + + def test_stable_diffusion_long_prompt(self): + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + do_classifier_free_guidance = True + negative_prompt = None + num_images_per_prompt = 1 + logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion") + + prompt = 25 * "@" + with CaptureLogger(logger) as cap_logger_3: + text_embeddings_3 = sd_pipe._encode_prompt( + prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + prompt = 100 * "@" + with CaptureLogger(logger) as cap_logger: + text_embeddings = sd_pipe._encode_prompt( + prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + negative_prompt = "Hello" + with CaptureLogger(logger) as cap_logger_2: + text_embeddings_2 = sd_pipe._encode_prompt( + prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape + assert text_embeddings.shape[1] == 77 + + assert cap_logger.out == cap_logger_2.out + # 100 - 77 + 1 (BOS token) + 1 (EOS token) = 25 + assert cap_logger.out.count("@") == 25 + assert cap_logger_3.out == "" + + def test_stable_diffusion_height_width_opt(self): + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "hey" + + output = sd_pipe(prompt, number_of_steps=1, output_type="np") + image_shape = output.images[0].shape[:2] + assert image_shape == (64, 64) + + output = sd_pipe(prompt, number_of_steps=1, height=96, width=96, output_type="np") + image_shape = output.images[0].shape[:2] + assert image_shape == (96, 96) + + config = dict(sd_pipe.unet.config) + config["sample_size"] = 96 + sd_pipe.unet = UNet2DConditionModel.from_config(config).to(torch_device) + output = sd_pipe(prompt, number_of_steps=1, output_type="np") + image_shape = output.images[0].shape[:2] + assert image_shape == (192, 192) + + +@slow +@require_torch_gpu +class StableDiffusionPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion(self): + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1") + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast("cuda"): + output = sd_pipe( + [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np" + ) + + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_fast_ddim(self): + scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-1", subfolder="scheduler") + + sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + + with torch.autocast("cuda"): + output = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy") + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.9326, 0.923, 0.951, 0.9365, 0.9214, 0.951, 0.9365, 0.9414, 0.918]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_lms_stable_diffusion_pipeline(self): + model_id = "CompVis/stable-diffusion-v1-1" + pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device) + pipe.set_progress_bar_config(disable=None) + scheduler = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe.scheduler = scheduler + + prompt = "a photograph of an astronaut riding a horse" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ).images + + image_slice = image[0, -3:, -3:, -1] + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.9077, 0.9254, 0.9181, 0.9227, 0.9213, 0.9367, 0.9399, 0.9406, 0.9024]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_memory_chunking(self): + torch.cuda.reset_peak_memory_stats() + model_id = "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "a photograph of an astronaut riding a horse" + + # make attention efficient + pipe.enable_attention_slicing() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output_chunked = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image_chunked = output_chunked.images + + mem_bytes = torch.cuda.max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + # make sure that less than 3.75 GB is allocated + assert mem_bytes < 3.75 * 10**9 + + # disable chunking + pipe.disable_attention_slicing() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image = output.images + + # make sure that more than 3.75 GB is allocated + mem_bytes = torch.cuda.max_memory_allocated() + assert mem_bytes > 3.75 * 10**9 + assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3 + + def test_stable_diffusion_text2img_pipeline_fp16(self): + torch.cuda.reset_peak_memory_stats() + model_id = "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "a photograph of an astronaut riding a horse" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output_chunked = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image_chunked = output_chunked.images + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image = output.images + + # Make sure results are close enough + diff = np.abs(image_chunked.flatten() - image.flatten()) + # They ARE different since ops are not run always at the same precision + # however, they should be extremely close. + assert diff.mean() < 2e-2 + + def test_stable_diffusion_text2img_pipeline_default(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text2img/astronaut_riding_a_horse.npy" + ) + + model_id = "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionPipeline.from_pretrained(model_id, safety_checker=None) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "astronaut riding a horse" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe(prompt=prompt, strength=0.75, guidance_scale=7.5, generator=generator, output_type="np") + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 5e-3 + + def test_stable_diffusion_text2img_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3 + elif step == 50: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [1.1078, 1.5803, 0.2773, -0.0589, -1.7928, -0.3665, -0.4695, -1.0727, -1.1601] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2 + + test_callback_fn.has_been_called = False + + pipe = StableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16 + ) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Andromeda galaxy in a bottle" + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + pipe( + prompt=prompt, + num_inference_steps=50, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 51 + + def test_stable_diffusion_low_cpu_mem_usage(self): + pipeline_id = "CompVis/stable-diffusion-v1-4" + + start_time = time.time() + pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float16 + ) + pipeline_low_cpu_mem_usage.to(torch_device) + low_cpu_mem_usage_time = time.time() - start_time + + start_time = time.time() + _ = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False + ) + normal_load_time = time.time() - start_time + + assert 2 * low_cpu_mem_usage_time < normal_load_time + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + pipeline_id = "CompVis/stable-diffusion-v1-4" + prompt = "Andromeda galaxy in a bottle" + + pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16) + pipeline = pipeline.to(torch_device) + pipeline.enable_attention_slicing(1) + pipeline.enable_sequential_cpu_offload() + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipeline(prompt, generator=generator, num_inference_steps=5) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 2.8 GB is allocated + assert mem_bytes < 2.8 * 10**9 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py new file mode 100644 index 000000000000..9b350d42e1f0 --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py @@ -0,0 +1,423 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import ( + AutoencoderKL, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionImageVariationPipeline, + UNet2DConditionModel, +) +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_image_encoder(self): + torch.manual_seed(0) + config = CLIPVisionConfig( + hidden_size=32, + projection_dim=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + image_size=32, + patch_size=4, + ) + return CLIPVisionModelWithProjection(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_stable_diffusion_img_variation_default_case(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + image_encoder = self.dummy_image_encoder + + init_image = self.dummy_image.to(device) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImageVariationPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + image_encoder=image_encoder, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + init_image, + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + ) + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + init_image, + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5093, 0.5717, 0.4806, 0.4891, 0.5552, 0.4594, 0.5177, 0.4894, 0.4904]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-3 + + def test_stable_diffusion_img_variation_multiple_images(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + image_encoder = self.dummy_image_encoder + + init_image = self.dummy_image.to(device).repeat(2, 1, 1, 1) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImageVariationPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + image_encoder=image_encoder, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + init_image, + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + ) + + image = output.images + + image_slice = image[-1, -3:, -3:, -1] + + assert image.shape == (2, 64, 64, 3) + expected_slice = np.array([0.6427, 0.5452, 0.5602, 0.5478, 0.5968, 0.6211, 0.5538, 0.5514, 0.5281]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_stable_diffusion_img_variation_num_images_per_prompt(self): + device = "cpu" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + image_encoder = self.dummy_image_encoder + + init_image = self.dummy_image.to(device) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImageVariationPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + image_encoder=image_encoder, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + # test num_images_per_prompt=1 (default) + images = sd_pipe( + init_image, + num_inference_steps=2, + output_type="np", + ).images + + assert images.shape == (1, 64, 64, 3) + + # test num_images_per_prompt=1 (default) for batch of images + batch_size = 2 + images = sd_pipe( + init_image.repeat(batch_size, 1, 1, 1), + num_inference_steps=2, + output_type="np", + ).images + + assert images.shape == (batch_size, 64, 64, 3) + + # test num_images_per_prompt for single prompt + num_images_per_prompt = 2 + images = sd_pipe( + init_image, + num_inference_steps=2, + output_type="np", + num_images_per_prompt=num_images_per_prompt, + ).images + + assert images.shape == (num_images_per_prompt, 64, 64, 3) + + # test num_images_per_prompt for batch of prompts + batch_size = 2 + images = sd_pipe( + init_image.repeat(batch_size, 1, 1, 1), + num_inference_steps=2, + output_type="np", + num_images_per_prompt=num_images_per_prompt, + ).images + + assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3) + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_img_variation_fp16(self): + """Test that stable diffusion img2img works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + image_encoder = self.dummy_image_encoder + + init_image = self.dummy_image.to(torch_device).float() + + # put models in fp16 + unet = unet.half() + vae = vae.half() + image_encoder = image_encoder.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImageVariationPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + image_encoder=image_encoder, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe( + init_image, + generator=generator, + num_inference_steps=2, + output_type="np", + ).images + + assert image.shape == (1, 64, 64, 3) + + +@slow +@require_torch_gpu +class StableDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_img_variation_pipeline_default(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/vermeer.jpg" + ) + init_image = init_image.resize((512, 512)) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/vermeer.npy" + ) + + model_id = "fusing/sd-image-variations-diffusers" + pipe = StableDiffusionImageVariationPipeline.from_pretrained( + model_id, + safety_checker=None, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + init_image, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + # img2img is flaky across GPUs even in fp32, so using MAE here + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_img_variation_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([1.83, 1.293, -0.09705, 1.256, -2.293, 1.091, -0.0809, -0.65, -2.953]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + elif step == 37: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([2.285, 2.703, 1.969, 0.696, -1.323, 0.9253, -0.5464, -1.521, -2.537]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2 + + test_callback_fn.has_been_called = False + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((512, 512)) + + pipe = StableDiffusionImageVariationPipeline.from_pretrained( + "fusing/sd-image-variations-diffusers", + torch_dtype=torch.float16, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + pipe( + init_image, + num_inference_steps=50, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 51 + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((512, 512)) + + model_id = "fusing/sd-image-variations-diffusers" + lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = StableDiffusionImageVariationPipeline.from_pretrained( + model_id, scheduler=lms, safety_checker=None, torch_dtype=torch.float16 + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing(1) + pipe.enable_sequential_cpu_offload() + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipe( + init_image, + guidance_scale=7.5, + generator=generator, + output_type="np", + num_inference_steps=5, + ) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 2.6 GB is allocated + assert mem_bytes < 2.6 * 10**9 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py new file mode 100644 index 000000000000..d86b259eae9e --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -0,0 +1,676 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionImg2ImgPipeline, + UNet2DConditionModel, + UNet2DModel, + VQModel, +) +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_uncond_unet(self): + torch.manual_seed(0) + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + return model + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_cond_unet_inpaint(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vq_model(self): + torch.manual_seed(0) + model = VQModel( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=3, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_stable_diffusion_img2img_default_case(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + init_image = self.dummy_image.to(device) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImg2ImgPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ) + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-3 + + def test_stable_diffusion_img2img_negative_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + init_image = self.dummy_image.to(device) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImg2ImgPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + negative_prompt = "french fries" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + prompt, + negative_prompt=negative_prompt, + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ) + image = output.images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.4065, 0.3783, 0.4050, 0.5266, 0.4781, 0.4252, 0.4203, 0.4692, 0.4365]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_stable_diffusion_img2img_multiple_init_images(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + init_image = self.dummy_image.to(device).repeat(2, 1, 1, 1) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImg2ImgPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = 2 * ["A painting of a squirrel eating a burger"] + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + prompt, + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ) + + image = output.images + + image_slice = image[-1, -3:, -3:, -1] + + assert image.shape == (2, 32, 32, 3) + expected_slice = np.array([0.5144, 0.4447, 0.4735, 0.6676, 0.5526, 0.5454, 0.645, 0.5149, 0.4689]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_stable_diffusion_img2img_k_lms(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + init_image = self.dummy_image.to(device) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImg2ImgPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ) + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + return_dict=False, + ) + image_from_tuple = output[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-3 + + def test_stable_diffusion_img2img_num_images_per_prompt(self): + device = "cpu" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + init_image = self.dummy_image.to(device) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImg2ImgPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + # test num_images_per_prompt=1 (default) + images = sd_pipe( + prompt, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ).images + + assert images.shape == (1, 32, 32, 3) + + # test num_images_per_prompt=1 (default) for batch of prompts + batch_size = 2 + images = sd_pipe( + [prompt] * batch_size, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ).images + + assert images.shape == (batch_size, 32, 32, 3) + + # test num_images_per_prompt for single prompt + num_images_per_prompt = 2 + images = sd_pipe( + prompt, + num_inference_steps=2, + output_type="np", + init_image=init_image, + num_images_per_prompt=num_images_per_prompt, + ).images + + assert images.shape == (num_images_per_prompt, 32, 32, 3) + + # test num_images_per_prompt for batch of prompts + batch_size = 2 + images = sd_pipe( + [prompt] * batch_size, + num_inference_steps=2, + output_type="np", + init_image=init_image, + num_images_per_prompt=num_images_per_prompt, + ).images + + assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3) + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_img2img_fp16(self): + """Test that stable diffusion img2img works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + init_image = self.dummy_image.to(torch_device) + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImg2ImgPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe( + [prompt], + generator=generator, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ).images + + assert image.shape == (1, 32, 32, 3) + + +@slow +@require_torch_gpu +class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_img2img_pipeline_default(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((768, 512)) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape.npy" + ) + + model_id = "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + model_id, + safety_checker=None, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "A fantasy landscape, trending on artstation" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 768, 3) + # img2img is flaky across GPUs even in fp32, so using MAE here + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_img2img_pipeline_k_lms(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((768, 512)) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape_k_lms.npy" + ) + + model_id = "CompVis/stable-diffusion-v1-4" + lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + model_id, + scheduler=lms, + safety_checker=None, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "A fantasy landscape, trending on artstation" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 768, 3) + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_img2img_pipeline_ddim(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((768, 512)) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape_ddim.npy" + ) + + model_id = "CompVis/stable-diffusion-v1-4" + ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + model_id, + scheduler=ddim, + safety_checker=None, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "A fantasy landscape, trending on artstation" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 768, 3) + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_img2img_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 96) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + elif step == 37: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 96) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([0.7071, 0.7831, 0.8300, 1.8140, 1.7840, 1.9402, 1.3651, 1.6590, 1.2828]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2 + + test_callback_fn.has_been_called = False + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((768, 512)) + + pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="fp16", + torch_dtype=torch.float16, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "A fantasy landscape, trending on artstation" + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + num_inference_steps=50, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 38 + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((768, 512)) + + model_id = "CompVis/stable-diffusion-v1-4" + lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + model_id, scheduler=lms, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing(1) + pipe.enable_sequential_cpu_offload() + + prompt = "A fantasy landscape, trending on artstation" + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + num_inference_steps=5, + ) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 2.2 GB is allocated + assert mem_bytes < 2.2 * 10**9 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py new file mode 100644 index 000000000000..e85fae939e9a --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -0,0 +1,679 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import ( + AutoencoderKL, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionInpaintPipeline, + UNet2DConditionModel, + UNet2DModel, + VQModel, +) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_uncond_unet(self): + torch.manual_seed(0) + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + return model + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_cond_unet_inpaint(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vq_model(self): + torch.manual_seed(0) + model = VQModel( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=3, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_stable_diffusion_inpaint(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet_inpaint + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionInpaintPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + image=init_image, + mask_image=mask_image, + ) + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + image=init_image, + mask_image=mask_image, + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_inpaint_with_num_images_per_prompt(self): + device = "cpu" + unet = self.dummy_cond_unet_inpaint + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionInpaintPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + images = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + image=init_image, + mask_image=mask_image, + num_images_per_prompt=2, + ).images + + # check if the output is a list of 2 images + assert len(images) == 2 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_inpaint_fp16(self): + """Test that stable diffusion inpaint_legacy works with fp16""" + unet = self.dummy_cond_unet_inpaint + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionInpaintPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe( + [prompt], + generator=generator, + num_inference_steps=2, + output_type="np", + image=init_image, + mask_image=mask_image, + ).images + + assert image.shape == (1, 64, 64, 3) + + +@slow +@require_torch_gpu +class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_inpaint_pipeline(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint" + "/yellow_cat_sitting_on_a_park_bench.npy" + ) + + model_id = "runwayml/stable-diffusion-inpainting" + pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_inpaint_pipeline_fp16(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint" + "/yellow_cat_sitting_on_a_park_bench_fp16.npy" + ) + + model_id = "runwayml/stable-diffusion-inpainting" + pipe = StableDiffusionInpaintPipeline.from_pretrained( + model_id, + revision="fp16", + torch_dtype=torch.float16, + safety_checker=None, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 5e-1 + + def test_stable_diffusion_inpaint_pipeline_pndm(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint" + "/yellow_cat_sitting_on_a_park_bench_pndm.npy" + ) + + model_id = "runwayml/stable-diffusion-inpainting" + pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-2 + + def test_stable_diffusion_inpaint_pipeline_k_lms(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint" + "/yellow_cat_sitting_on_a_park_bench_k_lms.npy" + ) + + model_id = "runwayml/stable-diffusion-inpainting" + pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None) + pipe.to(torch_device) + + # switch to LMS + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-2 + + @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU") + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + + model_id = "runwayml/stable-diffusion-inpainting" + pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = StableDiffusionInpaintPipeline.from_pretrained( + model_id, + safety_checker=None, + scheduler=pndm, + device_map="auto", + revision="fp16", + torch_dtype=torch.float16, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing(1) + pipe.enable_sequential_cpu_offload() + + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + generator=generator, + num_inference_steps=5, + output_type="np", + ) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 2.2 GB is allocated + assert mem_bytes < 2.2 * 10**9 + + +class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase): + def test_pil_inputs(self): + im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + im = Image.fromarray(im) + mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5 + mask = Image.fromarray((mask * 255).astype(np.uint8)) + + t_mask, t_masked = prepare_mask_and_masked_image(im, mask) + + self.assertTrue(isinstance(t_mask, torch.Tensor)) + self.assertTrue(isinstance(t_masked, torch.Tensor)) + + self.assertEqual(t_mask.ndim, 4) + self.assertEqual(t_masked.ndim, 4) + + self.assertEqual(t_mask.shape, (1, 1, 32, 32)) + self.assertEqual(t_masked.shape, (1, 3, 32, 32)) + + self.assertTrue(t_mask.dtype == torch.float32) + self.assertTrue(t_masked.dtype == torch.float32) + + self.assertTrue(t_mask.min() >= 0.0) + self.assertTrue(t_mask.max() <= 1.0) + self.assertTrue(t_masked.min() >= -1.0) + self.assertTrue(t_masked.min() <= 1.0) + + self.assertTrue(t_mask.sum() > 0.0) + + def test_np_inputs(self): + im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + im_pil = Image.fromarray(im_np) + mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5 + mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8)) + + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil) + + self.assertTrue((t_mask_np == t_mask_pil).all()) + self.assertTrue((t_masked_np == t_masked_pil).all()) + + def test_torch_3D_2D_inputs(self): + im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy().transpose(1, 2, 0) + mask_np = mask_tensor.numpy() + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_3D_3D_inputs(self): + im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy().transpose(1, 2, 0) + mask_np = mask_tensor.numpy()[0] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_4D_2D_inputs(self): + im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy()[0].transpose(1, 2, 0) + mask_np = mask_tensor.numpy() + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_4D_3D_inputs(self): + im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy()[0].transpose(1, 2, 0) + mask_np = mask_tensor.numpy()[0] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_4D_4D_inputs(self): + im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy()[0].transpose(1, 2, 0) + mask_np = mask_tensor.numpy()[0][0] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_batch_4D_3D(self): + im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5 + + im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] + mask_nps = [mask.numpy() for mask in mask_tensor] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)] + t_mask_np = torch.cat([n[0] for n in nps]) + t_masked_np = torch.cat([n[1] for n in nps]) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_batch_4D_4D(self): + im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5 + + im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] + mask_nps = [mask.numpy()[0] for mask in mask_tensor] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)] + t_mask_np = torch.cat([n[0] for n in nps]) + t_masked_np = torch.cat([n[1] for n in nps]) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_shape_mismatch(self): + # test height and width + with self.assertRaises(AssertionError): + prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64)) + # test batch dim + with self.assertRaises(AssertionError): + prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 64, 64)) + # test batch dim + with self.assertRaises(AssertionError): + prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 64, 64)) + + def test_type_mismatch(self): + # test tensors-only + with self.assertRaises(TypeError): + prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy()) + # test tensors-only + with self.assertRaises(TypeError): + prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32)) + + def test_channels_first(self): + # test channels first for 3D tensors + with self.assertRaises(AssertionError): + prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32)) + + def test_tensor_range(self): + # test im <= 1 + with self.assertRaises(ValueError): + prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32)) + # test im >= -1 + with self.assertRaises(ValueError): + prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32)) + # test mask <= 1 + with self.assertRaises(ValueError): + prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2) + # test mask >= 0 + with self.assertRaises(ValueError): + prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py new file mode 100644 index 000000000000..4b972c7b7d2a --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,487 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import ( + AutoencoderKL, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionInpaintPipeline, + StableDiffusionInpaintPipelineLegacy, + UNet2DConditionModel, + UNet2DModel, + VQModel, +) +from diffusers.utils import floats_tensor, load_image, slow, torch_device +from diffusers.utils.testing_utils import load_numpy, require_torch_gpu +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_uncond_unet(self): + torch.manual_seed(0) + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + return model + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_cond_unet_inpaint(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vq_model(self): + torch.manual_seed(0) + model = VQModel( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=3, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_stable_diffusion_inpaint_legacy(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB") + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32)) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionInpaintPipelineLegacy( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + mask_image=mask_image, + ) + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + mask_image=mask_image, + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_inpaint_legacy_negative_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB") + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32)) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionInpaintPipelineLegacy( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + negative_prompt = "french fries" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + prompt, + negative_prompt=negative_prompt, + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + mask_image=mask_image, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.4765, 0.5339, 0.4541, 0.6240, 0.5439, 0.4055, 0.5503, 0.5891, 0.5150]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_inpaint_legacy_num_images_per_prompt(self): + device = "cpu" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB") + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32)) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionInpaintPipelineLegacy( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + # test num_images_per_prompt=1 (default) + images = sd_pipe( + prompt, + num_inference_steps=2, + output_type="np", + init_image=init_image, + mask_image=mask_image, + ).images + + assert images.shape == (1, 32, 32, 3) + + # test num_images_per_prompt=1 (default) for batch of prompts + batch_size = 2 + images = sd_pipe( + [prompt] * batch_size, + num_inference_steps=2, + output_type="np", + init_image=init_image, + mask_image=mask_image, + ).images + + assert images.shape == (batch_size, 32, 32, 3) + + # test num_images_per_prompt for single prompt + num_images_per_prompt = 2 + images = sd_pipe( + prompt, + num_inference_steps=2, + output_type="np", + init_image=init_image, + mask_image=mask_image, + num_images_per_prompt=num_images_per_prompt, + ).images + + assert images.shape == (num_images_per_prompt, 32, 32, 3) + + # test num_images_per_prompt for batch of prompts + batch_size = 2 + images = sd_pipe( + [prompt] * batch_size, + num_inference_steps=2, + output_type="np", + init_image=init_image, + mask_image=mask_image, + num_images_per_prompt=num_images_per_prompt, + ).images + + assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3) + + +@slow +@require_torch_gpu +class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_inpaint_legacy_pipeline(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint" + "/red_cat_sitting_on_a_park_bench.npy" + ) + + model_id = "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "A red cat sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + init_image=init_image, + mask_image=mask_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_inpaint_legacy_pipeline_k_lms(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint" + "/red_cat_sitting_on_a_park_bench_k_lms.npy" + ) + + model_id = "CompVis/stable-diffusion-v1-4" + lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = StableDiffusionInpaintPipeline.from_pretrained( + model_id, + scheduler=lms, + safety_checker=None, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "A red cat sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + init_image=init_image, + mask_image=mask_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_inpaint_legacy_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + elif step == 37: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([0.4781, 1.1572, 0.6258, 0.2291, 0.2554, -0.1443, 0.7085, -0.1598, -0.5659]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + + test_callback_fn.has_been_called = False + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "A red cat sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + pipe( + prompt=prompt, + init_image=init_image, + mask_image=mask_image, + strength=0.75, + num_inference_steps=50, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 38 diff --git a/tests/pipelines/stable_diffusion_2/__init__.py b/tests/pipelines/stable_diffusion_2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py new file mode 100644 index 000000000000..fad3b89f056b --- /dev/null +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -0,0 +1,733 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import time +import unittest + +import numpy as np +import torch + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, + logging, +) +from diffusers.utils import load_numpy, slow, torch_device +from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusion2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + # SD2-specific config below + attention_head_dim=(2, 4, 8, 8), + use_linear_projection=True, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=512, + ) + return CLIPTextModel(config) + + def test_save_pretrained_from_pretrained(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + image = output.images + + with tempfile.TemporaryDirectory() as tmpdirname: + sd_pipe.save_pretrained(tmpdirname) + sd_pipe = StableDiffusionPipeline.from_pretrained(tmpdirname) + sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + generator = generator.manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + new_image = output.images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass" + + def test_stable_diffusion_ddim(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5649, 0.6022, 0.4804, 0.5270, 0.5585, 0.4643, 0.5159, 0.4963, 0.4793]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_pndm(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5099, 0.5677, 0.4671, 0.5128, 0.5697, 0.4676, 0.5277, 0.4964, 0.4946]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_k_lms(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4717, 0.5376, 0.4568, 0.5225, 0.5734, 0.4797, 0.5467, 0.5074, 0.5043]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_k_euler_ancestral(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = EulerAncestralDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4715, 0.5376, 0.4569, 0.5224, 0.5734, 0.4797, 0.5465, 0.5074, 0.5046]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_k_euler(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4717, 0.5376, 0.4568, 0.5225, 0.5734, 0.4797, 0.5467, 0.5074, 0.5043]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_attention_chunk(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output_1 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + # make sure chunking the attention yields the same result + sd_pipe.enable_attention_slicing(slice_size=1) + generator = torch.Generator(device=device).manual_seed(0) + output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_fp16(self): + """Test that stable diffusion works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images + + assert image.shape == (1, 64, 64, 3) + + def test_stable_diffusion_long_prompt(self): + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + do_classifier_free_guidance = True + negative_prompt = None + num_images_per_prompt = 1 + logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion") + + prompt = 25 * "@" + with CaptureLogger(logger) as cap_logger_3: + text_embeddings_3 = sd_pipe._encode_prompt( + prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + prompt = 100 * "@" + with CaptureLogger(logger) as cap_logger: + text_embeddings = sd_pipe._encode_prompt( + prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + negative_prompt = "Hello" + with CaptureLogger(logger) as cap_logger_2: + text_embeddings_2 = sd_pipe._encode_prompt( + prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape + assert text_embeddings.shape[1] == 77 + + assert cap_logger.out == cap_logger_2.out + # 100 - 77 + 1 (BOS token) + 1 (EOS token) = 25 + assert cap_logger.out.count("@") == 25 + assert cap_logger_3.out == "" + + +@slow +@require_torch_gpu +class StableDiffusion2PipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion(self): + sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base") + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np") + + image = output.images + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0788, 0.0823, 0.1091, 0.1165, 0.1263, 0.1459, 0.1317, 0.1507, 0.1551]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_ddim(self): + scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-base", subfolder="scheduler") + sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base", scheduler=scheduler) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + + output = sd_pipe([prompt], generator=generator, num_inference_steps=5, output_type="numpy") + image = output.images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0642, 0.0382, 0.0408, 0.0395, 0.0227, 0.0942, 0.0749, 0.0669, 0.0248]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_k_lms(self): + scheduler = LMSDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2-base", subfolder="scheduler") + sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base", scheduler=scheduler) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "a photograph of an astronaut riding a horse" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=5, output_type="numpy" + ).images + + image_slice = image[0, 253:256, 253:256, -1] + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0548, 0.0626, 0.0612, 0.0611, 0.0706, 0.0586, 0.0843, 0.0333, 0.1197]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_attention_slicing(self): + torch.cuda.reset_peak_memory_stats() + model_id = "stabilityai/stable-diffusion-2-base" + pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "a photograph of an astronaut riding a horse" + + # make attention efficient + pipe.enable_attention_slicing() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output_chunked = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image_chunked = output_chunked.images + + mem_bytes = torch.cuda.max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + # make sure that less than 3.75 GB is allocated + assert mem_bytes < 3.75 * 10**9 + + # disable chunking + pipe.disable_attention_slicing() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image = output.images + + # make sure that more than 3.75 GB is allocated + mem_bytes = torch.cuda.max_memory_allocated() + assert mem_bytes > 3.75 * 10**9 + assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3 + + def test_stable_diffusion_same_quality(self): + torch.cuda.reset_peak_memory_stats() + model_id = "stabilityai/stable-diffusion-2-base" + pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16) + pipe = pipe.to(torch_device) + pipe.enable_attention_slicing() + pipe.set_progress_bar_config(disable=None) + + prompt = "a photograph of an astronaut riding a horse" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output_chunked = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image_chunked = output_chunked.images + + pipe = StableDiffusionPipeline.from_pretrained(model_id) + pipe = pipe.to(torch_device) + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy") + image = output.images + + # Make sure results are close enough + diff = np.abs(image_chunked.flatten() - image.flatten()) + # They ARE different since ops are not run always at the same precision + # however, they should be extremely close. + assert diff.mean() < 5e-2 + + def test_stable_diffusion_text2img_pipeline_default(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-text2img/astronaut_riding_a_horse.npy" + ) + + model_id = "stabilityai/stable-diffusion-2-base" + pipe = StableDiffusionPipeline.from_pretrained(model_id, safety_checker=None) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "astronaut riding a horse" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np") + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 5e-3 + + def test_stable_diffusion_text2img_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([1.8606, 1.3169, -0.0691, 1.2374, -2.309, 1.077, -0.1084, -0.6774, -2.9594]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3 + elif step == 20: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([1.0757, 1.1860, 1.1410, 0.4645, -0.2476, 0.6100, -0.7755, -0.8841, -0.9497]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2 + + test_callback_fn.has_been_called = False + + pipe = StableDiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-base", revision="fp16", torch_dtype=torch.float16 + ) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Andromeda galaxy in a bottle" + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + pipe( + prompt=prompt, + num_inference_steps=20, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 21 + + def test_stable_diffusion_low_cpu_mem_usage(self): + pipeline_id = "stabilityai/stable-diffusion-2-base" + + start_time = time.time() + pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float16 + ) + pipeline_low_cpu_mem_usage.to(torch_device) + low_cpu_mem_usage_time = time.time() - start_time + + start_time = time.time() + _ = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False + ) + normal_load_time = time.time() - start_time + + assert 2 * low_cpu_mem_usage_time < normal_load_time + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + pipeline_id = "stabilityai/stable-diffusion-2-base" + prompt = "Andromeda galaxy in a bottle" + + pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16) + pipeline = pipeline.to(torch_device) + pipeline.enable_attention_slicing(1) + pipeline.enable_sequential_cpu_offload() + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipeline(prompt, generator=generator, num_inference_steps=5) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 2.8 GB is allocated + assert mem_bytes < 2.8 * 10**9 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py new file mode 100644 index 000000000000..b420570f0707 --- /dev/null +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py @@ -0,0 +1,345 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel +from diffusers.utils import floats_tensor, load_image, load_numpy, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_cond_unet_inpaint(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + # SD2-specific config below + attention_head_dim=(2, 4, 8, 8), + use_linear_projection=True, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=512, + ) + return CLIPTextModel(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_stable_diffusion_inpaint(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet_inpaint + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + text_encoder = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionInpaintPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + image=init_image, + mask_image=mask_image, + ) + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + image=init_image, + mask_image=mask_image, + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4727, 0.5735, 0.3941, 0.5446, 0.5926, 0.4394, 0.5062, 0.4654, 0.4476]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_inpaint_fp16(self): + """Test that stable diffusion inpaint works with fp16""" + unet = self.dummy_cond_unet_inpaint + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + text_encoder = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) + + # put models in fp16 + unet = unet.half() + vae = vae.half() + text_encoder = text_encoder.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionInpaintPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe( + [prompt], + generator=generator, + num_inference_steps=2, + output_type="np", + image=init_image, + mask_image=mask_image, + ).images + + assert image.shape == (1, 64, 64, 3) + + +# @slow +@require_torch_gpu +class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_inpaint_pipeline(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/sd2-inpaint/init_image.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint" + "/yellow_cat_sitting_on_a_park_bench.npy" + ) + + model_id = "stabilityai/stable-diffusion-2-inpainting" + pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_inpaint_pipeline_fp16(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/sd2-inpaint/init_image.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint" + "/yellow_cat_sitting_on_a_park_bench_fp16.npy" + ) + + model_id = "stabilityai/stable-diffusion-2-inpainting" + pipe = StableDiffusionInpaintPipeline.from_pretrained( + model_id, + revision="fp16", + torch_dtype=torch.float16, + safety_checker=None, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 5e-1 + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/sd2-inpaint/init_image.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png" + ) + + model_id = "stabilityai/stable-diffusion-2-inpainting" + pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = StableDiffusionInpaintPipeline.from_pretrained( + model_id, + safety_checker=None, + scheduler=pndm, + device_map="auto", + revision="fp16", + torch_dtype=torch.float16, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing(1) + pipe.enable_sequential_cpu_offload() + + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + generator=generator, + num_inference_steps=5, + output_type="np", + ) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 2.65 GB is allocated + assert mem_bytes < 2.65 * 10**9 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py new file mode 100644 index 000000000000..2092e153eeb2 --- /dev/null +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py @@ -0,0 +1,315 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionUpscalePipeline, UNet2DConditionModel +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusionUpscalePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_cond_unet_upscale(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 32, 64), + layers_per_block=2, + sample_size=32, + in_channels=7, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + # SD2-specific config below + attention_head_dim=8, + use_linear_projection=True, + only_cross_attention=(True, True, False), + num_class_embeds=100, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=512, + ) + return CLIPTextModel(config) + + def test_stable_diffusion_upscale(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet_upscale + low_res_scheduler = DDPMScheduler() + scheduler = DDIMScheduler(prediction_type="v_prediction") + vae = self.dummy_vae + text_encoder = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionUpscalePipeline( + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + max_noise_level=350, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + [prompt], + image=low_res_image, + generator=generator, + guidance_scale=6.0, + noise_level=20, + num_inference_steps=2, + output_type="np", + ) + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + image=low_res_image, + generator=generator, + guidance_scale=6.0, + noise_level=20, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + expected_height_width = low_res_image.size[0] * 4 + assert image.shape == (1, expected_height_width, expected_height_width, 3) + expected_slice = np.array([0.2562, 0.3606, 0.4204, 0.4469, 0.4822, 0.4647, 0.5315, 0.5748, 0.5606]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_upscale_fp16(self): + """Test that stable diffusion upscale works with fp16""" + unet = self.dummy_cond_unet_upscale + low_res_scheduler = DDPMScheduler() + scheduler = DDIMScheduler(prediction_type="v_prediction") + vae = self.dummy_vae + text_encoder = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + + # put models in fp16, except vae as it overflows in fp16 + unet = unet.half() + text_encoder = text_encoder.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionUpscalePipeline( + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + max_noise_level=350, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe( + [prompt], + image=low_res_image, + generator=generator, + num_inference_steps=2, + output_type="np", + ).images + + expected_height_width = low_res_image.size[0] * 4 + assert image.shape == (1, expected_height_width, expected_height_width, 3) + + +@slow +@require_torch_gpu +class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_upscale_pipeline(self): + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/sd2-upscale/low_res_cat.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale" + "/upsampled_cat.npy" + ) + + model_id = "stabilityai/stable-diffusion-x4-upscaler" + pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "a cat sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=image, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_upscale_pipeline_fp16(self): + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/sd2-upscale/low_res_cat.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale" + "/upsampled_cat_fp16.npy" + ) + + model_id = "stabilityai/stable-diffusion-x4-upscaler" + pipe = StableDiffusionUpscalePipeline.from_pretrained( + model_id, + revision="fp16", + torch_dtype=torch.float16, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "a cat sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=image, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 5e-1 + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/sd2-upscale/low_res_cat.png" + ) + + model_id = "stabilityai/stable-diffusion-x4-upscaler" + pipe = StableDiffusionUpscalePipeline.from_pretrained( + model_id, + revision="fp16", + torch_dtype=torch.float16, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing(1) + pipe.enable_sequential_cpu_offload() + + prompt = "a cat sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipe( + prompt=prompt, + image=image, + generator=generator, + num_inference_steps=5, + output_type="np", + ) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 2.65 GB is allocated + assert mem_bytes < 2.65 * 10**9 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py new file mode 100644 index 000000000000..cfc450db4a86 --- /dev/null +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py @@ -0,0 +1,474 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import time +import unittest + +import numpy as np +import torch + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerDiscreteScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.utils import load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusion2VPredictionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + # SD2-specific config below + attention_head_dim=(2, 4, 8, 8), + use_linear_projection=True, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=64, + ) + return CLIPTextModel(config) + + def test_stable_diffusion_v_pred_ddim(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + prediction_type="v_prediction", + ) + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.6424, 0.6109, 0.494, 0.5088, 0.4984, 0.4525, 0.5059, 0.5068, 0.4474]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_v_pred_k_euler(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="v_prediction" + ) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4616, 0.5184, 0.4887, 0.5111, 0.4839, 0.48, 0.5119, 0.5263, 0.4776]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_v_pred_fp16(self): + """Test that stable diffusion v-prediction works with fp16""" + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + prediction_type="v_prediction", + ) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images + + assert image.shape == (1, 64, 64, 3) + + +@slow +@require_torch_gpu +class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_v_pred_default(self): + sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2") + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.enable_attention_slicing() + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=20, output_type="np") + + image = output.images + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 768, 768, 3) + expected_slice = np.array([0.0567, 0.057, 0.0416, 0.0463, 0.0433, 0.06, 0.0517, 0.0526, 0.0866]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_v_pred_euler(self): + scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler") + sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=scheduler) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.enable_attention_slicing() + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + + output = sd_pipe([prompt], generator=generator, num_inference_steps=5, output_type="numpy") + image = output.images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 768, 768, 3) + expected_slice = np.array([0.0351, 0.0376, 0.0505, 0.0424, 0.0551, 0.0656, 0.0471, 0.0276, 0.0596]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_v_pred_dpm(self): + """ + TODO: update this test after making DPM compatible with V-prediction! + """ + scheduler = DPMSolverMultistepScheduler.from_pretrained( + "stabilityai/stable-diffusion-2", subfolder="scheduler" + ) + sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=scheduler) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.enable_attention_slicing() + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "a photograph of an astronaut riding a horse" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=5, output_type="numpy" + ).images + + image_slice = image[0, 253:256, 253:256, -1] + assert image.shape == (1, 768, 768, 3) + expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_attention_slicing_v_pred(self): + torch.cuda.reset_peak_memory_stats() + model_id = "stabilityai/stable-diffusion-2" + pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "a photograph of an astronaut riding a horse" + + # make attention efficient + pipe.enable_attention_slicing() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output_chunked = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image_chunked = output_chunked.images + + mem_bytes = torch.cuda.max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + # make sure that less than 5.5 GB is allocated + assert mem_bytes < 5.5 * 10**9 + + # disable slicing + pipe.disable_attention_slicing() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image = output.images + + # make sure that more than 5.5 GB is allocated + mem_bytes = torch.cuda.max_memory_allocated() + assert mem_bytes > 5.5 * 10**9 + assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3 + + def test_stable_diffusion_text2img_pipeline_v_pred_default(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/" + "sd2-text2img/astronaut_riding_a_horse_v_pred.npy" + ) + + pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2") + pipe.to(torch_device) + pipe.enable_attention_slicing() + pipe.set_progress_bar_config(disable=None) + + prompt = "astronaut riding a horse" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np") + image = output.images[0] + + assert image.shape == (768, 768, 3) + assert np.abs(expected_image - image).max() < 5e-3 + + def test_stable_diffusion_text2img_pipeline_v_pred_fp16(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/" + "sd2-text2img/astronaut_riding_a_horse_v_pred_fp16.npy" + ) + + pipe = StableDiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2", revision="fp16", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "astronaut riding a horse" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np") + image = output.images[0] + + assert image.shape == (768, 768, 3) + assert np.abs(expected_image - image).max() < 5e-3 + + def test_stable_diffusion_text2img_intermediate_state_v_pred(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 96, 96) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.2543, -1.2755, 0.4261, -0.9555, -1.173, -0.5892, 2.4159, 0.1554, -1.2098] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3 + elif step == 19: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 96, 96) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.9572, -0.967, -0.6152, 0.0894, -0.699, -0.2344, 1.5465, -0.0357, -0.1141] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2 + + test_callback_fn.has_been_called = False + + pipe = StableDiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2", revision="fp16", torch_dtype=torch.float16 + ) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Andromeda galaxy in a bottle" + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + pipe( + prompt=prompt, + num_inference_steps=20, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 20 + + def test_stable_diffusion_low_cpu_mem_usage_v_pred(self): + pipeline_id = "stabilityai/stable-diffusion-2" + + start_time = time.time() + pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float16 + ) + pipeline_low_cpu_mem_usage.to(torch_device) + low_cpu_mem_usage_time = time.time() - start_time + + start_time = time.time() + _ = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float16, low_cpu_mem_usage=False + ) + normal_load_time = time.time() - start_time + + assert 2 * low_cpu_mem_usage_time < normal_load_time + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading_v_pred(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + pipeline_id = "stabilityai/stable-diffusion-2" + prompt = "Andromeda galaxy in a bottle" + + pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16) + pipeline = pipeline.to(torch_device) + pipeline.enable_attention_slicing(1) + pipeline.enable_sequential_cpu_offload() + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipeline(prompt, generator=generator, num_inference_steps=5) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 2.8 GB is allocated + assert mem_bytes < 2.8 * 10**9 diff --git a/tests/pipelines/stable_diffusion_safe/__init__.py b/tests/pipelines/stable_diffusion_safe/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py new file mode 100644 index 000000000000..dbb991479377 --- /dev/null +++ b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py @@ -0,0 +1,435 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline +from diffusers.utils import floats_tensor, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class SafeDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_safe_diffusion_ddim(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5644, 0.6018, 0.4799, 0.5267, 0.5585, 0.4641, 0.516, 0.4964, 0.4792]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_pndm(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5095, 0.5674, 0.4668, 0.5126, 0.5697, 0.4675, 0.5278, 0.4964, 0.4945]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_no_safety_checker(self): + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None + ) + assert isinstance(pipe, StableDiffusionPipeline) + assert isinstance(pipe.scheduler, LMSDiscreteScheduler) + assert pipe.safety_checker is None + + image = pipe("example prompt", num_inference_steps=2).images[0] + assert image is not None + + # check that there's no error when saving a pipeline with one of the models being None + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = StableDiffusionPipeline.from_pretrained(tmpdirname) + + # sanity check that the pipeline still works + assert pipe.safety_checker is None + image = pipe("example prompt", num_inference_steps=2).images[0] + assert image is not None + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_fp16(self): + """Test that stable diffusion works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images + + assert image.shape == (1, 64, 64, 3) + + +@slow +@require_torch_gpu +class SafeDiffusionPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_harm_safe_stable_diffusion(self): + sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None) + sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = ( + "portrait of girl with smokey eyes makeup in abandoned hotel, grange clothes, redshift, wide high angle" + " coloured polaroid photograph with flash, kodak film, hyper real, stunning moody cinematography, with" + " anamorphic lenses, by maripol, fallen angels by wong kar - wai, style of suspiria and neon demon and" + " children from bahnhof zoo, detailed " + ) + seed = 4003660346 + guidance_scale = 7 + + # without safety guidance (sld_guidance_scale = 0) + generator = torch.Generator(device=torch_device).manual_seed(seed) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + sld_guidance_scale=0, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176] + + assert image.shape == (1, 512, 512, 3) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + # without safety guidance (strong configuration) + generator = torch.Generator(device=torch_device).manual_seed(seed) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + sld_guidance_scale=2000, + sld_warmup_steps=7, + sld_threshold=0.025, + sld_momentum_scale=0.5, + sld_mom_beta=0.7, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719] + + assert image.shape == (1, 512, 512, 3) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_nudity_safe_stable_diffusion(self): + sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None) + sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "padme amidala taking a bath artwork, safe for work, no nudity" + seed = 2734971755 + guidance_scale = 7 + + generator = torch.Generator(device=torch_device).manual_seed(seed) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + sld_guidance_scale=0, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297] + + assert image.shape == (1, 512, 512, 3) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + generator = torch.Generator(device=torch_device).manual_seed(seed) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + sld_guidance_scale=2000, + sld_warmup_steps=7, + sld_threshold=0.025, + sld_momentum_scale=0.5, + sld_mom_beta=0.7, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443] + + assert image.shape == (1, 512, 512, 3) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_nudity_safetychecker_safe_stable_diffusion(self): + sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = ( + "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c." + " leyendecker" + ) + seed = 1044355234 + guidance_scale = 12 + + generator = torch.Generator(device=torch_device).manual_seed(seed) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + sld_guidance_scale=0, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + + assert image.shape == (1, 512, 512, 3) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-7 + + generator = torch.Generator(device=torch_device).manual_seed(seed) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + sld_guidance_scale=2000, + sld_warmup_steps=7, + sld_threshold=0.025, + sld_momentum_scale=0.5, + sld_mom_beta=0.7, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561]) + assert image.shape == (1, 512, 512, 3) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/versatile_diffusion/__init__.py b/tests/pipelines/versatile_diffusion/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py new file mode 100644 index 000000000000..9fb6ca522f5d --- /dev/null +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py @@ -0,0 +1,112 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import VersatileDiffusionDualGuidedPipeline +from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VersatileDiffusionDualGuidedPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pass + + +@slow +@require_torch_gpu +class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_remove_unused_weights_save_load(self): + pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("shi-labs/versatile-diffusion") + # remove text_unet + pipe.remove_unused_weights() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + second_prompt = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + prompt="first prompt", + image=second_prompt, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=2, + output_type="numpy", + ).images + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(tmpdirname) + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = generator.manual_seed(0) + new_image = pipe( + prompt="first prompt", + image=second_prompt, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=2, + output_type="numpy", + ).images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass" + + def test_inference_dual_guided(self): + pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("shi-labs/versatile-diffusion") + pipe.remove_unused_weights() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + first_prompt = "cyberpunk 2077" + second_prompt = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + prompt=first_prompt, + image=second_prompt, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=50, + output_type="numpy", + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.014, 0.0112, 0.0136, 0.0145, 0.0107, 0.0113, 0.0272, 0.0215, 0.0216]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_variation.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_variation.py new file mode 100644 index 000000000000..1711b752992f --- /dev/null +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_variation.py @@ -0,0 +1,58 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import VersatileDiffusionImageVariationPipeline +from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VersatileDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pass + + +@slow +@require_torch_gpu +class VersatileDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase): + def test_inference_image_variations(self): + pipe = VersatileDiffusionImageVariationPipeline.from_pretrained("shi-labs/versatile-diffusion") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + image_prompt = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + image=image_prompt, + generator=generator, + guidance_scale=7.5, + num_inference_steps=50, + output_type="numpy", + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.1205, 0.1914, 0.2289, 0.0883, 0.1595, 0.1683, 0.0703, 0.1493, 0.1298]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py new file mode 100644 index 000000000000..ab4580dae1fe --- /dev/null +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import VersatileDiffusionPipeline +from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VersatileDiffusionMegaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pass + + +@slow +@require_torch_gpu +class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_from_pretrained_save_pretrained(self): + pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt_image = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe.dual_guided( + prompt="first prompt", + image=prompt_image, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=2, + output_type="numpy", + ).images + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = VersatileDiffusionPipeline.from_pretrained(tmpdirname, torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = generator.manual_seed(0) + new_image = pipe.dual_guided( + prompt="first prompt", + image=prompt_image, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=2, + output_type="numpy", + ).images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass" + + def test_inference_dual_guided_then_text_to_image(self): + pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "cyberpunk 2077" + init_image = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe.dual_guided( + prompt=prompt, + image=init_image, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=50, + output_type="numpy", + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0081, 0.0032, 0.0002, 0.0056, 0.0027, 0.0000, 0.0051, 0.0020, 0.0007]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + prompt = "A painting of a squirrel eating a burger " + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe.text_to_image( + prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy" + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0408, 0.0181, 0.0, 0.0388, 0.0046, 0.0461, 0.0411, 0.0, 0.0222]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + image = pipe.image_variation(init_image, generator=generator, output_type="numpy").images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.3403, 0.1809, 0.0938, 0.3855, 0.2393, 0.1243, 0.4028, 0.3110, 0.1799]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_text_to_image.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_text_to_image.py new file mode 100644 index 000000000000..027819efee9f --- /dev/null +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_text_to_image.py @@ -0,0 +1,86 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import VersatileDiffusionTextToImagePipeline +from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VersatileDiffusionTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pass + + +@slow +@require_torch_gpu +class VersatileDiffusionTextToImagePipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_remove_unused_weights_save_load(self): + pipe = VersatileDiffusionTextToImagePipeline.from_pretrained("shi-labs/versatile-diffusion") + # remove text_unet + pipe.remove_unused_weights() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger " + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=2, output_type="numpy" + ).images + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = VersatileDiffusionTextToImagePipeline.from_pretrained(tmpdirname) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = generator.manual_seed(0) + new_image = pipe( + prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=2, output_type="numpy" + ).images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass" + + def test_inference_text2img(self): + pipe = VersatileDiffusionTextToImagePipeline.from_pretrained("shi-labs/versatile-diffusion") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger " + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy" + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0408, 0.0181, 0.0, 0.0388, 0.0046, 0.0461, 0.0411, 0.0, 0.0222]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/vq_diffusion/__init__.py b/tests/pipelines/vq_diffusion/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/vq_diffusion/test_vq_diffusion.py b/tests/pipelines/vq_diffusion/test_vq_diffusion.py new file mode 100644 index 000000000000..87e29cbc97de --- /dev/null +++ b/tests/pipelines/vq_diffusion/test_vq_diffusion.py @@ -0,0 +1,228 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch + +from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel +from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings +from diffusers.utils import load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def num_embed(self): + return 12 + + @property + def num_embeds_ada_norm(self): + return 12 + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def dummy_vqvae(self): + torch.manual_seed(0) + model = VQModel( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=3, + num_vq_embeddings=self.num_embed, + vq_embed_dim=3, + ) + return model + + @property + def dummy_tokenizer(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=self.text_embedder_hidden_size, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config) + + @property + def dummy_transformer(self): + torch.manual_seed(0) + + height = 12 + width = 12 + + model_kwargs = { + "attention_bias": True, + "cross_attention_dim": 32, + "attention_head_dim": height * width, + "num_attention_heads": 1, + "num_vector_embeds": self.num_embed, + "num_embeds_ada_norm": self.num_embeds_ada_norm, + "norm_num_groups": 32, + "sample_size": width, + "activation_fn": "geglu-approximate", + } + + model = Transformer2DModel(**model_kwargs) + return model + + def test_vq_diffusion(self): + device = "cpu" + + vqvae = self.dummy_vqvae + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + transformer = self.dummy_transformer + scheduler = VQDiffusionScheduler(self.num_embed) + learned_classifier_free_sampling_embeddings = LearnedClassifierFreeSamplingEmbeddings(learnable=False) + + pipe = VQDiffusionPipeline( + vqvae=vqvae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, + ) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + prompt = "teddy bear playing in the pool" + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = pipe( + [prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2 + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 24, 24, 3) + + expected_slice = np.array([0.6583, 0.6410, 0.5325, 0.5635, 0.5563, 0.4234, 0.6008, 0.5491, 0.4880]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_vq_diffusion_classifier_free_sampling(self): + device = "cpu" + + vqvae = self.dummy_vqvae + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + transformer = self.dummy_transformer + scheduler = VQDiffusionScheduler(self.num_embed) + learned_classifier_free_sampling_embeddings = LearnedClassifierFreeSamplingEmbeddings( + learnable=True, hidden_size=self.text_embedder_hidden_size, length=tokenizer.model_max_length + ) + + pipe = VQDiffusionPipeline( + vqvae=vqvae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, + ) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + prompt = "teddy bear playing in the pool" + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = pipe( + [prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2 + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 24, 24, 3) + + expected_slice = np.array([0.6647, 0.6531, 0.5303, 0.5891, 0.5726, 0.4439, 0.6304, 0.5564, 0.4912]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch_gpu +class VQDiffusionPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_vq_diffusion_classifier_free_sampling(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/vq_diffusion/teddy_bear_pool_classifier_free_sampling.npy" + ) + + pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq") + pipeline = pipeline.to(torch_device) + pipeline.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipeline( + "teddy bear playing in the pool", + num_images_per_prompt=1, + generator=generator, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (256, 256, 3) + assert np.abs(expected_image - image).max() < 1e-2 diff --git a/tests/repo_utils/test_check_copies.py b/tests/repo_utils/test_check_copies.py new file mode 100644 index 000000000000..65128f68d1ac --- /dev/null +++ b/tests/repo_utils/test_check_copies.py @@ -0,0 +1,120 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import shutil +import sys +import tempfile +import unittest + +import black + + +git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +sys.path.append(os.path.join(git_repo_path, "utils")) + +import check_copies # noqa: E402 + + +# This is the reference code that will be used in the tests. +# If DDPMSchedulerOutput is changed in scheduling_ddpm.py, this code needs to be manually updated. +REFERENCE_CODE = """ \""" + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + \""" + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None +""" + + +class CopyCheckTester(unittest.TestCase): + def setUp(self): + self.diffusers_dir = tempfile.mkdtemp() + os.makedirs(os.path.join(self.diffusers_dir, "schedulers/")) + check_copies.DIFFUSERS_PATH = self.diffusers_dir + shutil.copy( + os.path.join(git_repo_path, "src/diffusers/schedulers/scheduling_ddpm.py"), + os.path.join(self.diffusers_dir, "schedulers/scheduling_ddpm.py"), + ) + + def tearDown(self): + check_copies.DIFFUSERS_PATH = "src/diffusers" + shutil.rmtree(self.diffusers_dir) + + def check_copy_consistency(self, comment, class_name, class_code, overwrite_result=None): + code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code + if overwrite_result is not None: + expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result + mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119) + code = black.format_str(code, mode=mode) + fname = os.path.join(self.diffusers_dir, "new_code.py") + with open(fname, "w", newline="\n") as f: + f.write(code) + if overwrite_result is None: + self.assertTrue(len(check_copies.is_copy_consistent(fname)) == 0) + else: + check_copies.is_copy_consistent(f.name, overwrite=True) + with open(fname, "r") as f: + self.assertTrue(f.read(), expected) + + def test_find_code_in_diffusers(self): + code = check_copies.find_code_in_diffusers("schedulers.scheduling_ddpm.DDPMSchedulerOutput") + self.assertEqual(code, REFERENCE_CODE) + + def test_is_copy_consistent(self): + # Base copy consistency + self.check_copy_consistency( + "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput", + "DDPMSchedulerOutput", + REFERENCE_CODE + "\n", + ) + + # With no empty line at the end + self.check_copy_consistency( + "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput", + "DDPMSchedulerOutput", + REFERENCE_CODE, + ) + + # Copy consistency with rename + self.check_copy_consistency( + "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->Test", + "TestSchedulerOutput", + re.sub("DDPM", "Test", REFERENCE_CODE), + ) + + # Copy consistency with a really long name + long_class_name = "TestClassWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason" + self.check_copy_consistency( + f"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->{long_class_name}", + f"{long_class_name}SchedulerOutput", + re.sub("Bert", long_class_name, REFERENCE_CODE), + ) + + # Copy consistency with overwrite + self.check_copy_consistency( + "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->Test", + "TestSchedulerOutput", + REFERENCE_CODE, + overwrite_result=re.sub("DDPM", "Test", REFERENCE_CODE), + ) diff --git a/tests/repo_utils/test_check_dummies.py b/tests/repo_utils/test_check_dummies.py new file mode 100644 index 000000000000..d8fa9ce10547 --- /dev/null +++ b/tests/repo_utils/test_check_dummies.py @@ -0,0 +1,124 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import unittest + + +git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +sys.path.append(os.path.join(git_repo_path, "utils")) + +import check_dummies +from check_dummies import create_dummy_files, create_dummy_object, find_backend, read_init # noqa: E402 + + +# Align TRANSFORMERS_PATH in check_dummies with the current path +check_dummies.PATH_TO_DIFFUSERS = os.path.join(git_repo_path, "src", "diffusers") + + +class CheckDummiesTester(unittest.TestCase): + def test_find_backend(self): + simple_backend = find_backend(" if not is_torch_available():") + self.assertEqual(simple_backend, "torch") + + # backend_with_underscore = find_backend(" if not is_tensorflow_text_available():") + # self.assertEqual(backend_with_underscore, "tensorflow_text") + + double_backend = find_backend(" if not (is_torch_available() and is_transformers_available()):") + self.assertEqual(double_backend, "torch_and_transformers") + + # double_backend_with_underscore = find_backend( + # " if not (is_sentencepiece_available() and is_tensorflow_text_available()):" + # ) + # self.assertEqual(double_backend_with_underscore, "sentencepiece_and_tensorflow_text") + + triple_backend = find_backend( + " if not (is_torch_available() and is_transformers_available() and is_onnx_available()):" + ) + self.assertEqual(triple_backend, "torch_and_transformers_and_onnx") + + def test_read_init(self): + objects = read_init() + # We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects + self.assertIn("torch", objects) + self.assertIn("torch_and_transformers", objects) + self.assertIn("flax_and_transformers", objects) + self.assertIn("torch_and_transformers_and_onnx", objects) + + # Likewise, we can't assert on the exact content of a key + self.assertIn("UNet2DModel", objects["torch"]) + self.assertIn("FlaxUNet2DConditionModel", objects["flax"]) + self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"]) + self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"]) + self.assertIn("LMSDiscreteScheduler", objects["torch_and_scipy"]) + self.assertIn("OnnxStableDiffusionPipeline", objects["torch_and_transformers_and_onnx"]) + + def test_create_dummy_object(self): + dummy_constant = create_dummy_object("CONSTANT", "'torch'") + self.assertEqual(dummy_constant, "\nCONSTANT = None\n") + + dummy_function = create_dummy_object("function", "'torch'") + self.assertEqual( + dummy_function, "\ndef function(*args, **kwargs):\n requires_backends(function, 'torch')\n" + ) + + expected_dummy_class = """ +class FakeClass(metaclass=DummyObject): + _backends = 'torch' + + def __init__(self, *args, **kwargs): + requires_backends(self, 'torch') + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, 'torch') + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, 'torch') +""" + dummy_class = create_dummy_object("FakeClass", "'torch'") + self.assertEqual(dummy_class, expected_dummy_class) + + def test_create_dummy_files(self): + expected_dummy_pytorch_file = """# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +CONSTANT = None + + +def function(*args, **kwargs): + requires_backends(function, ["torch"]) + + +class FakeClass(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) +""" + dummy_files = create_dummy_files({"torch": ["CONSTANT", "function", "FakeClass"]}) + self.assertEqual(dummy_files["torch"], expected_dummy_pytorch_file) diff --git a/tests/test_config.py b/tests/test_config.py old mode 100755 new mode 100644 index ab5b1e86bca8..2a021c4ced5f --- a/tests/test_config.py +++ b/tests/test_config.py @@ -16,7 +16,18 @@ import tempfile import unittest +from diffusers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + PNDMScheduler, + logging, +) from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import deprecate +from diffusers.utils.testing_utils import CaptureLogger class SampleObject(ConfigMixin): @@ -34,10 +45,41 @@ def __init__( pass +class SampleObject2(ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + f=[1, 3], + ): + pass + + +class SampleObject3(ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + e=[1, 3], + f=[1, 3], + ): + pass + + class ConfigTester(unittest.TestCase): def test_load_not_from_mixin(self): with self.assertRaises(ValueError): - ConfigMixin.from_config("dummy_path") + ConfigMixin.load_config("dummy_path") def test_register_to_config(self): obj = SampleObject() @@ -87,7 +129,7 @@ def test_save_load(self): with tempfile.TemporaryDirectory() as tmpdirname: obj.save_config(tmpdirname) - new_obj = SampleObject.from_config(tmpdirname) + new_obj = SampleObject.from_config(SampleObject.load_config(tmpdirname)) new_config = new_obj.config # unfreeze configs @@ -97,3 +139,96 @@ def test_save_load(self): assert config.pop("c") == (2, 5) # instantiated as tuple assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json assert config == new_config + + def test_load_ddim_from_pndm(self): + logger = logging.get_logger("diffusers.configuration_utils") + + with CaptureLogger(logger) as cap_logger: + ddim = DDIMScheduler.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" + ) + + assert ddim.__class__ == DDIMScheduler + # no warning should be thrown + assert cap_logger.out == "" + + def test_load_euler_from_pndm(self): + logger = logging.get_logger("diffusers.configuration_utils") + + with CaptureLogger(logger) as cap_logger: + euler = EulerDiscreteScheduler.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" + ) + + assert euler.__class__ == EulerDiscreteScheduler + # no warning should be thrown + assert cap_logger.out == "" + + def test_load_euler_ancestral_from_pndm(self): + logger = logging.get_logger("diffusers.configuration_utils") + + with CaptureLogger(logger) as cap_logger: + euler = EulerAncestralDiscreteScheduler.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" + ) + + assert euler.__class__ == EulerAncestralDiscreteScheduler + # no warning should be thrown + assert cap_logger.out == "" + + def test_load_pndm(self): + logger = logging.get_logger("diffusers.configuration_utils") + + with CaptureLogger(logger) as cap_logger: + pndm = PNDMScheduler.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" + ) + + assert pndm.__class__ == PNDMScheduler + # no warning should be thrown + assert cap_logger.out == "" + + def test_overwrite_config_on_load(self): + logger = logging.get_logger("diffusers.configuration_utils") + + with CaptureLogger(logger) as cap_logger: + ddpm = DDPMScheduler.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", + subfolder="scheduler", + prediction_type="sample", + beta_end=8, + ) + + with CaptureLogger(logger) as cap_logger_2: + ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88) + + with CaptureLogger(logger) as cap_logger: + deprecate("remove this case", "0.10.0", "remove") + ddpm_3 = DDPMScheduler.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", + subfolder="scheduler", + predict_epsilon=False, + beta_end=8, + ) + + assert ddpm.__class__ == DDPMScheduler + assert ddpm.config.prediction_type == "sample" + assert ddpm.config.beta_end == 8 + assert ddpm_2.config.beta_start == 88 + assert ddpm_3.config.prediction_type == "sample" + + # no warning should be thrown + assert cap_logger.out == "" + assert cap_logger_2.out == "" + + def test_load_dpmsolver(self): + logger = logging.get_logger("diffusers.configuration_utils") + + with CaptureLogger(logger) as cap_logger: + dpm = DPMSolverMultistepScheduler.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" + ) + + assert dpm.__class__ == DPMSolverMultistepScheduler + # no warning should be thrown + assert cap_logger.out == "" diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index cf531fbf3fd3..911ec548b3bd 100755 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -18,8 +18,9 @@ import numpy as np import torch +from torch import nn -from diffusers.models.attention import AttentionBlock, SpatialTransformer +from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock, Transformer2DModel from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.resnet import Downsample2D, Upsample2D from diffusers.utils import torch_device @@ -235,7 +236,7 @@ def test_attention_block_default(self): num_head_channels=1, rescale_output_factor=1.0, eps=1e-6, - num_groups=32, + norm_num_groups=32, ).to(torch_device) with torch.no_grad(): attention_scores = attentionBlock(sample) @@ -259,7 +260,7 @@ def test_attention_block_sd(self): channels=512, rescale_output_factor=1.0, eps=1e-6, - num_groups=32, + norm_num_groups=32, ).to(torch_device) with torch.no_grad(): attention_scores = attentionBlock(sample) @@ -273,22 +274,22 @@ def test_attention_block_sd(self): assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) -class SpatialTransformerTests(unittest.TestCase): +class Transformer2DModelTests(unittest.TestCase): def test_spatial_transformer_default(self): torch.manual_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) sample = torch.randn(1, 32, 64, 64).to(torch_device) - spatial_transformer_block = SpatialTransformer( + spatial_transformer_block = Transformer2DModel( in_channels=32, - n_heads=1, - d_head=32, + num_attention_heads=1, + attention_head_dim=32, dropout=0.0, - context_dim=None, + cross_attention_dim=None, ).to(torch_device) with torch.no_grad(): - attention_scores = spatial_transformer_block(sample) + attention_scores = spatial_transformer_block(sample).sample assert attention_scores.shape == (1, 32, 64, 64) output_slice = attention_scores[0, -1, -3:, -3:] @@ -298,22 +299,22 @@ def test_spatial_transformer_default(self): ) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) - def test_spatial_transformer_context_dim(self): + def test_spatial_transformer_cross_attention_dim(self): torch.manual_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) sample = torch.randn(1, 64, 64, 64).to(torch_device) - spatial_transformer_block = SpatialTransformer( + spatial_transformer_block = Transformer2DModel( in_channels=64, - n_heads=2, - d_head=32, + num_attention_heads=2, + attention_head_dim=32, dropout=0.0, - context_dim=64, + cross_attention_dim=64, ).to(torch_device) with torch.no_grad(): context = torch.randn(1, 4, 64).to(torch_device) - attention_scores = spatial_transformer_block(sample, context) + attention_scores = spatial_transformer_block(sample, context).sample assert attention_scores.shape == (1, 64, 64, 64) output_slice = attention_scores[0, -1, -3:, -3:] @@ -323,6 +324,44 @@ def test_spatial_transformer_context_dim(self): ) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + def test_spatial_transformer_timestep(self): + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_embeds_ada_norm = 5 + + sample = torch.randn(1, 64, 64, 64).to(torch_device) + spatial_transformer_block = Transformer2DModel( + in_channels=64, + num_attention_heads=2, + attention_head_dim=32, + dropout=0.0, + cross_attention_dim=64, + num_embeds_ada_norm=num_embeds_ada_norm, + ).to(torch_device) + with torch.no_grad(): + timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device) + timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device) + attention_scores_1 = spatial_transformer_block(sample, timestep=timestep_1).sample + attention_scores_2 = spatial_transformer_block(sample, timestep=timestep_2).sample + + assert attention_scores_1.shape == (1, 64, 64, 64) + assert attention_scores_2.shape == (1, 64, 64, 64) + + output_slice_1 = attention_scores_1[0, -1, -3:, -3:] + output_slice_2 = attention_scores_2[0, -1, -3:, -3:] + + expected_slice_1 = torch.tensor( + [-0.1874, -0.9704, -1.4290, -1.3357, 1.5138, 0.3036, -0.0976, -1.1667, 0.1283], device=torch_device + ) + expected_slice_2 = torch.tensor( + [-0.3493, -1.0924, -1.6161, -1.5016, 1.4245, 0.1367, -0.2526, -1.3109, -0.0547], device=torch_device + ) + + assert torch.allclose(output_slice_1.flatten(), expected_slice_1, atol=1e-3) + assert torch.allclose(output_slice_2.flatten(), expected_slice_2, atol=1e-3) + def test_spatial_transformer_dropout(self): torch.manual_seed(0) if torch.cuda.is_available(): @@ -330,18 +369,18 @@ def test_spatial_transformer_dropout(self): sample = torch.randn(1, 32, 64, 64).to(torch_device) spatial_transformer_block = ( - SpatialTransformer( + Transformer2DModel( in_channels=32, - n_heads=2, - d_head=16, + num_attention_heads=2, + attention_head_dim=16, dropout=0.3, - context_dim=None, + cross_attention_dim=None, ) .to(torch_device) .eval() ) with torch.no_grad(): - attention_scores = spatial_transformer_block(sample) + attention_scores = spatial_transformer_block(sample).sample assert attention_scores.shape == (1, 32, 64, 64) output_slice = attention_scores[0, -1, -3:, -3:] @@ -350,3 +389,107 @@ def test_spatial_transformer_dropout(self): [-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091], device=torch_device ) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + @unittest.skipIf(torch_device == "mps", "MPS does not support float64") + def test_spatial_transformer_discrete(self): + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_embed = 5 + + sample = torch.randint(0, num_embed, (1, 32)).to(torch_device) + spatial_transformer_block = ( + Transformer2DModel( + num_attention_heads=1, + attention_head_dim=32, + num_vector_embeds=num_embed, + sample_size=16, + ) + .to(torch_device) + .eval() + ) + + with torch.no_grad(): + attention_scores = spatial_transformer_block(sample).sample + + assert attention_scores.shape == (1, num_embed - 1, 32) + + output_slice = attention_scores[0, -2:, -3:] + + expected_slice = torch.tensor([-0.8957, -1.8370, -1.3390, -0.9152, -0.5187, -1.1702], device=torch_device) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + def test_spatial_transformer_default_norm_layers(self): + spatial_transformer_block = Transformer2DModel(num_attention_heads=1, attention_head_dim=32, in_channels=32) + + assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == nn.LayerNorm + assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == nn.LayerNorm + assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm + + def test_spatial_transformer_ada_norm_layers(self): + spatial_transformer_block = Transformer2DModel( + num_attention_heads=1, + attention_head_dim=32, + in_channels=32, + num_embeds_ada_norm=5, + ) + + assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == AdaLayerNorm + assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == AdaLayerNorm + assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm + + def test_spatial_transformer_default_ff_layers(self): + spatial_transformer_block = Transformer2DModel( + num_attention_heads=1, + attention_head_dim=32, + in_channels=32, + ) + + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU + assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear + + dim = 32 + inner_dim = 128 + + # First dimension change + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim + # NOTE: inner_dim * 2 because GEGLU + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim * 2 + + # Second dimension change + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim + + def test_spatial_transformer_geglu_approx_ff_layers(self): + spatial_transformer_block = Transformer2DModel( + num_attention_heads=1, + attention_head_dim=32, + in_channels=32, + activation_fn="geglu-approximate", + ) + + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU + assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear + + dim = 32 + inner_dim = 128 + + # First dimension change + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim + + # Second dimension change + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim + + def test_spatial_transformer_attention_bias(self): + spatial_transformer_block = Transformer2DModel( + num_attention_heads=1, attention_head_dim=32, in_channels=32, attention_bias=True + ) + + assert spatial_transformer_block.transformer_blocks[0].attn1.to_q.bias is not None + assert spatial_transformer_block.transformer_blocks[0].attn1.to_k.bias is not None + assert spatial_transformer_block.transformer_blocks[0].attn1.to_v.bias is not None diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index eabe6ada9fea..cad1887f4df8 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -130,7 +130,7 @@ def test_forward_signature(self): expected_arg_names = ["sample", "timestep"] self.assertListEqual(arg_names[:2], expected_arg_names) - def test_model_from_config(self): + def test_model_from_pretrained(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) @@ -140,8 +140,8 @@ def test_model_from_config(self): # test if the model can be loaded from the config # and has all the expected shape with tempfile.TemporaryDirectory() as tmpdirname: - model.save_config(tmpdirname) - new_model = self.model_class.from_config(tmpdirname) + model.save_pretrained(tmpdirname) + new_model = self.model_class.from_pretrained(tmpdirname) new_model.to(torch_device) new_model.eval() @@ -265,3 +265,23 @@ def test_enable_disable_gradient_checkpointing(self): # check disable works model.disable_gradient_checkpointing() self.assertFalse(model.is_gradient_checkpointing) + + def test_deprecated_kwargs(self): + has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters + has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are" + " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to" + f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument" + " from `_deprecated_kwargs = []`" + ) diff --git a/tests/test_modeling_common_flax.py b/tests/test_modeling_common_flax.py index 61849b22318f..8945aed7c93f 100644 --- a/tests/test_modeling_common_flax.py +++ b/tests/test_modeling_common_flax.py @@ -1,3 +1,5 @@ +import inspect + from diffusers.utils import is_flax_available from diffusers.utils.testing_utils import require_flax @@ -42,3 +44,23 @@ def test_forward_with_norm_groups(self): self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_deprecated_kwargs(self): + has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters + has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are" + " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to" + f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument" + " from `_deprecated_kwargs = []`" + ) diff --git a/tests/test_models_vae.py b/tests/test_models_vae.py deleted file mode 100644 index 9fb7e8ea3bb7..000000000000 --- a/tests/test_models_vae.py +++ /dev/null @@ -1,118 +0,0 @@ -# coding=utf-8 -# Copyright 2022 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import torch - -from diffusers import AutoencoderKL -from diffusers.modeling_utils import ModelMixin -from diffusers.utils import floats_tensor, torch_device - -from .test_modeling_common import ModelTesterMixin - - -torch.backends.cuda.matmul.allow_tf32 = False - - -class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): - model_class = AutoencoderKL - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - - return {"sample": image} - - @property - def input_shape(self): - return (3, 32, 32) - - @property - def output_shape(self): - return (3, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "block_out_channels": [32, 64], - "in_channels": 3, - "out_channels": 3, - "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], - "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], - "latent_channels": 4, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_forward_signature(self): - pass - - def test_training(self): - pass - - def test_from_pretrained_hub(self): - model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy") - model = model.to(torch_device) - model.eval() - - # One-time warmup pass (see #372) - if torch_device == "mps" and isinstance(model, ModelMixin): - image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) - image = image.to(torch_device) - with torch.no_grad(): - _ = model(image, sample_posterior=True).sample - generator = torch.manual_seed(0) - else: - generator = torch.Generator(device=torch_device).manual_seed(0) - - image = torch.randn( - 1, - model.config.in_channels, - model.config.sample_size, - model.config.sample_size, - generator=torch.manual_seed(0), - ) - image = image.to(torch_device) - with torch.no_grad(): - output = model(image, sample_posterior=True, generator=generator).sample - - output_slice = output[0, -1, -3:, -3:].flatten().cpu() - - # Since the VAE Gaussian prior's generator is seeded on the appropriate device, - # the expected output slices are not the same for CPU and GPU. - if torch_device in ("mps", "cpu"): - expected_output_slice = torch.tensor( - [-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026] - ) - else: - expected_output_slice = torch.tensor( - [-0.2421, 0.4642, 0.2507, -0.0438, 0.0682, 0.3160, -0.2018, -0.0727, 0.2485] - ) - - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 69a45c4247aa..0aad9de8be02 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -14,46 +14,40 @@ # limitations under the License. import gc +import json import os import random +import shutil import tempfile -import tracemalloc import unittest import numpy as np import torch -import accelerate import PIL -import transformers from diffusers import ( AutoencoderKL, DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, - KarrasVePipeline, - KarrasVeScheduler, - LDMPipeline, - LDMTextToImagePipeline, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, LMSDiscreteScheduler, - PNDMPipeline, PNDMScheduler, - ScoreSdeVePipeline, - ScoreSdeVeScheduler, StableDiffusionImg2ImgPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionOnnxPipeline, + StableDiffusionInpaintPipelineLegacy, StableDiffusionPipeline, UNet2DConditionModel, UNet2DModel, - VQModel, + logging, ) from diffusers.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device -from diffusers.utils.testing_utils import get_tests_dir -from packaging import version +from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device +from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu +from parameterized import parameterized from PIL import Image from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -84,11 +78,106 @@ def test_progress_bar(capsys): assert captured.err == "", "Progress bar should be disabled" +class DownloadTests(unittest.TestCase): + def test_download_only_pytorch(self): + with tempfile.TemporaryDirectory() as tmpdirname: + # pipeline has Flax weights + _ = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname + ) + + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))] + files = [item for sublist in all_root_files for item in sublist] + + # None of the downloaded files should be a flax file even if we have some here: + # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack + assert not any(f.endswith(".msgpack") for f in files) + + def test_download_no_safety_checker(self): + prompt = "hello" + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None + ) + pipe = pipe.to(torch_device) + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images + + pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch") + pipe_2 = pipe_2.to(torch_device) + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images + + assert np.max(np.abs(out - out_2)) < 1e-3 + + def test_load_no_safety_checker_explicit_locally(self): + prompt = "hello" + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None + ) + pipe = pipe.to(torch_device) + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None) + pipe_2 = pipe_2.to(torch_device) + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images + + assert np.max(np.abs(out - out_2)) < 1e-3 + + def test_load_no_safety_checker_default_locally(self): + prompt = "hello" + pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch") + pipe = pipe.to(torch_device) + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname) + pipe_2 = pipe_2.to(torch_device) + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images + + assert np.max(np.abs(out - out_2)) < 1e-3 + + class CustomPipelineTests(unittest.TestCase): def test_load_custom_pipeline(self): pipeline = DiffusionPipeline.from_pretrained( "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline" ) + pipeline = pipeline.to(torch_device) # NOTE that `"CustomPipeline"` is not a class that is defined in this library, but solely on the Hub # under https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L24 assert pipeline.__class__.__name__ == "CustomPipeline" @@ -97,17 +186,34 @@ def test_run_custom_pipeline(self): pipeline = DiffusionPipeline.from_pretrained( "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline" ) + pipeline = pipeline.to(torch_device) images, output_str = pipeline(num_inference_steps=2, output_type="np") assert images[0].shape == (1, 32, 32, 3) + # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102 assert output_str == "This is a test" - def test_local_custom_pipeline(self): + def test_local_custom_pipeline_repo(self): + local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline") + pipeline = DiffusionPipeline.from_pretrained( + "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path + ) + pipeline = pipeline.to(torch_device) + images, output_str = pipeline(num_inference_steps=2, output_type="np") + + assert pipeline.__class__.__name__ == "CustomLocalPipeline" + assert images[0].shape == (1, 32, 32, 3) + # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102 + assert output_str == "This is a local test" + + def test_local_custom_pipeline_file(self): local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline") + local_custom_pipeline_path = os.path.join(local_custom_pipeline_path, "what_ever.py") pipeline = DiffusionPipeline.from_pretrained( "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path ) + pipeline = pipeline.to(torch_device) images, output_str = pipeline(num_inference_steps=2, output_type="np") assert pipeline.__class__.__name__ == "CustomLocalPipeline" @@ -116,7 +222,7 @@ def test_local_custom_pipeline(self): assert output_str == "This is a local test" @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") + @require_torch_gpu def test_load_pipeline_from_git(self): clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" @@ -143,13 +249,6 @@ def test_load_pipeline_from_git(self): class PipelineFastTests(unittest.TestCase): - def tearDown(self): - # clean up the VRAM after each test - super().tearDown() - gc.collect() - torch.cuda.empty_cache() - - @property def dummy_image(self): batch_size = 1 num_channels = 3 @@ -158,13 +257,12 @@ def dummy_image(self): image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) return image - @property - def dummy_uncond_unet(self): + def dummy_uncond_unet(self, sample_size=32): torch.manual_seed(0) model = UNet2DModel( block_out_channels=(32, 64), layers_per_block=2, - sample_size=32, + sample_size=sample_size, in_channels=3, out_channels=3, down_block_types=("DownBlock2D", "AttnDownBlock2D"), @@ -172,13 +270,12 @@ def dummy_uncond_unet(self): ) return model - @property - def dummy_cond_unet(self): + def dummy_cond_unet(self, sample_size=32): torch.manual_seed(0) model = UNet2DConditionModel( block_out_channels=(32, 64), layers_per_block=2, - sample_size=32, + sample_size=sample_size, in_channels=4, out_channels=4, down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), @@ -187,19 +284,6 @@ def dummy_cond_unet(self): ) return model - @property - def dummy_vq_model(self): - torch.manual_seed(0) - model = VQModel( - block_out_channels=[32, 64], - in_channels=3, - out_channels=3, - down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], - up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], - latent_channels=3, - ) - return model - @property def dummy_vae(self): torch.manual_seed(0) @@ -229,13 +313,6 @@ def dummy_text_encoder(self): ) return CLIPTextModel(config) - @property - def dummy_safety_checker(self): - def check(images, *args, **kwargs): - return images, [False] * len(images) - - return check - @property def dummy_extractor(self): def extract(*args, **kwargs): @@ -251,1104 +328,315 @@ def to(self, device): return extract - def test_ddim(self): - unet = self.dummy_uncond_unet - scheduler = DDIMScheduler() - - ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) - ddpm.to(torch_device) - ddpm.set_progress_bar_config(disable=None) - - # Warmup pass when using mps (see #372) - if torch_device == "mps": - _ = ddpm(num_inference_steps=1) - - generator = torch.manual_seed(0) - image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images - - generator = torch.manual_seed(0) - image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0] - - image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array( - [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04] - ) - tolerance = 1e-2 if torch_device != "mps" else 3e-2 - assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance - - def test_pndm_cifar10(self): - unet = self.dummy_uncond_unet - scheduler = PNDMScheduler() - - pndm = PNDMPipeline(unet=unet, scheduler=scheduler) - pndm.to(torch_device) - pndm.set_progress_bar_config(disable=None) - - generator = torch.manual_seed(0) - image = pndm(generator=generator, num_inference_steps=20, output_type="numpy").images - - generator = torch.manual_seed(0) - image_from_tuple = pndm(generator=generator, num_inference_steps=20, output_type="numpy", return_dict=False)[0] - - image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - - def test_ldm_text2img(self): - unet = self.dummy_cond_unet - scheduler = DDIMScheduler() - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) - ldm.to(torch_device) - ldm.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" + @parameterized.expand( + [ + [DDIMScheduler, DDIMPipeline, 32], + [DDPMScheduler, DDPMPipeline, 32], + [DDIMScheduler, DDIMPipeline, (32, 64)], + [DDPMScheduler, DDPMPipeline, (64, 32)], + ] + ) + def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32): + unet = self.dummy_uncond_unet(sample_size) + scheduler = scheduler_fn() + pipeline = pipeline_fn(unet, scheduler).to(torch_device) - # Warmup pass when using mps (see #372) + # Device type MPS is not supported for torch.Generator() api. if torch_device == "mps": generator = torch.manual_seed(0) - _ = ldm( - [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy" - ).images - - generator = torch.manual_seed(0) - image = ldm( - [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy" - ).images + else: + generator = torch.Generator(device=torch_device).manual_seed(0) - generator = torch.manual_seed(0) - image_from_tuple = ldm( - [prompt], + out_image = pipeline( generator=generator, - guidance_scale=6.0, num_inference_steps=2, - output_type="numpy", - return_dict=False, - )[0] - - image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - - assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - - def test_stable_diffusion_ddim(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_cond_unet - scheduler = DDIMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False, - ) + output_type="np", + ).images + sample_size = (sample_size, sample_size) if isinstance(sample_size, int) else sample_size + assert out_image.shape == (1, *sample_size, 3) + def test_stable_diffusion_components(self): + """Test that components property works correctly""" + unet = self.dummy_cond_unet() + scheduler = PNDMScheduler(skip_prk_steps=True) vae = self.dummy_vae bert = self.dummy_text_encoder tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB") + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32)) + # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionPipeline( + inpaint = StableDiffusionInpaintPipelineLegacy( unet=unet, scheduler=scheduler, vae=vae, text_encoder=bert, tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, + safety_checker=None, feature_extractor=self.dummy_extractor, - ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) + ).to(torch_device) + img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device) + text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device) prompt = "A painting of a squirrel eating a burger" - generator = torch.Generator(device=device).manual_seed(0) - output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") - image = output.images + # Device type MPS is not supported for torch.Generator() api. + if torch_device == "mps": + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) - generator = torch.Generator(device=device).manual_seed(0) - image_from_tuple = sd_pipe( + image_inpaint = inpaint( [prompt], generator=generator, - guidance_scale=6.0, num_inference_steps=2, output_type="np", - return_dict=False, - )[0] - - image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - - assert image.shape == (1, 128, 128, 3) - expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755]) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - - def test_stable_diffusion_ddim_factor_8(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_cond_unet - scheduler = DDIMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False, - ) - - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionPipeline( - unet=unet, - scheduler=scheduler, - vae=vae, - text_encoder=bert, - tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, - feature_extractor=self.dummy_extractor, - ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - - generator = torch.Generator(device=device).manual_seed(0) - output = sd_pipe( + init_image=init_image, + mask_image=mask_image, + ).images + image_img2img = img2img( [prompt], generator=generator, - guidance_scale=6.0, - height=536, - width=536, num_inference_steps=2, output_type="np", - ) - image = output.images - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 134, 134, 3) - expected_slice = np.array([0.7834, 0.5488, 0.5781, 0.46, 0.3609, 0.5369, 0.542, 0.4855, 0.5557]) + init_image=init_image, + ).images + image_text2img = text2img( + [prompt], + generator=generator, + num_inference_steps=2, + output_type="np", + ).images - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert image_inpaint.shape == (1, 32, 32, 3) + assert image_img2img.shape == (1, 32, 32, 3) + assert image_text2img.shape == (1, 64, 64, 3) - def test_stable_diffusion_pndm(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_cond_unet + def test_set_scheduler(self): + unet = self.dummy_cond_unet() scheduler = PNDMScheduler(skip_prk_steps=True) vae = self.dummy_vae bert = self.dummy_text_encoder tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionPipeline( + sd = StableDiffusionPipeline( unet=unet, scheduler=scheduler, vae=vae, text_encoder=bert, tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, + safety_checker=None, feature_extractor=self.dummy_extractor, ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - generator = torch.Generator(device=device).manual_seed(0) - output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") - - image = output.images - - generator = torch.Generator(device=device).manual_seed(0) - image_from_tuple = sd_pipe( - [prompt], - generator=generator, - guidance_scale=6.0, - num_inference_steps=2, - output_type="np", - return_dict=False, - )[0] - - image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - - assert image.shape == (1, 128, 128, 3) - expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - - def test_from_pretrained_error_message_uninstalled_packages(self): - # TODO(Patrick, Pedro) - need better test here for the future - pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-lms-pipe") - assert isinstance(pipe, StableDiffusionPipeline) - assert isinstance(pipe.scheduler, LMSDiscreteScheduler) - - def test_stable_diffusion_no_safety_checker(self): - pipe = StableDiffusionPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None - ) - assert isinstance(pipe, StableDiffusionPipeline) - assert isinstance(pipe.scheduler, LMSDiscreteScheduler) - assert pipe.safety_checker is None - image = pipe("example prompt", num_inference_steps=2).images[0] - assert image is not None - - def test_stable_diffusion_k_lms(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_cond_unet - scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, DDIMScheduler) + sd.scheduler = DDPMScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, DDPMScheduler) + sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, PNDMScheduler) + sd.scheduler = LMSDiscreteScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, LMSDiscreteScheduler) + sd.scheduler = EulerDiscreteScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, EulerDiscreteScheduler) + sd.scheduler = EulerAncestralDiscreteScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, EulerAncestralDiscreteScheduler) + sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, DPMSolverMultistepScheduler) + + def test_set_scheduler_consistency(self): + unet = self.dummy_cond_unet() + pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") + ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") vae = self.dummy_vae bert = self.dummy_text_encoder tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionPipeline( + sd = StableDiffusionPipeline( unet=unet, - scheduler=scheduler, + scheduler=pndm, vae=vae, text_encoder=bert, tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, + safety_checker=None, feature_extractor=self.dummy_extractor, ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - generator = torch.Generator(device=device).manual_seed(0) - output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") - - image = output.images - - generator = torch.Generator(device=device).manual_seed(0) - image_from_tuple = sd_pipe( - [prompt], - generator=generator, - guidance_scale=6.0, - num_inference_steps=2, - output_type="np", - return_dict=False, - )[0] - - image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) - expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + pndm_config = sd.scheduler.config + sd.scheduler = DDPMScheduler.from_config(pndm_config) + sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config) + pndm_config_2 = sd.scheduler.config + pndm_config_2 = {k: v for k, v in pndm_config_2.items() if k in pndm_config} - def test_stable_diffusion_attention_chunk(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_cond_unet - scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + assert dict(pndm_config) == dict(pndm_config_2) - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionPipeline( + sd = StableDiffusionPipeline( unet=unet, - scheduler=scheduler, + scheduler=ddim, vae=vae, text_encoder=bert, tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, + safety_checker=None, feature_extractor=self.dummy_extractor, ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - generator = torch.Generator(device=device).manual_seed(0) - output_1 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") - # make sure chunking the attention yields the same result - sd_pipe.enable_attention_slicing(slice_size=1) - generator = torch.Generator(device=device).manual_seed(0) - output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + ddim_config = sd.scheduler.config + sd.scheduler = LMSDiscreteScheduler.from_config(ddim_config) + sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config) + ddim_config_2 = sd.scheduler.config + ddim_config_2 = {k: v for k, v in ddim_config_2.items() if k in ddim_config} - assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4 + assert dict(ddim_config) == dict(ddim_config_2) - def test_stable_diffusion_negative_prompt(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_cond_unet - scheduler = PNDMScheduler(skip_prk_steps=True) + def test_optional_components(self): + unet = self.dummy_cond_unet() + pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") vae = self.dummy_vae bert = self.dummy_text_encoder tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionPipeline( + orig_sd = StableDiffusionPipeline( unet=unet, - scheduler=scheduler, + scheduler=pndm, vae=vae, text_encoder=bert, tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, + safety_checker=unet, feature_extractor=self.dummy_extractor, ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - negative_prompt = "french fries" - generator = torch.Generator(device=device).manual_seed(0) - output = sd_pipe( - prompt, - negative_prompt=negative_prompt, - generator=generator, - guidance_scale=6.0, - num_inference_steps=2, - output_type="np", - ) + sd = orig_sd - image = output.images - image_slice = image[0, -3:, -3:, -1] + assert sd.config.requires_safety_checker is True - assert image.shape == (1, 128, 128, 3) - expected_slice = np.array([0.4851, 0.4617, 0.4765, 0.5127, 0.4845, 0.5153, 0.5141, 0.4886, 0.4719]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - def test_score_sde_ve_pipeline(self): - unet = self.dummy_uncond_unet - scheduler = ScoreSdeVeScheduler() - - sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler) - sde_ve.to(torch_device) - sde_ve.set_progress_bar_config(disable=None) + with tempfile.TemporaryDirectory() as tmpdirname: + sd.save_pretrained(tmpdirname) - generator = torch.manual_seed(0) - image = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator).images + # Test that passing None works + sd = StableDiffusionPipeline.from_pretrained( + tmpdirname, feature_extractor=None, safety_checker=None, requires_safety_checker=False + ) - generator = torch.manual_seed(0) - image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator, return_dict=False)[ - 0 - ] + assert sd.config.requires_safety_checker is False + assert sd.config.safety_checker == (None, None) + assert sd.config.feature_extractor == (None, None) - image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + with tempfile.TemporaryDirectory() as tmpdirname: + sd.save_pretrained(tmpdirname) - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + # Test that loading previous None works + sd = StableDiffusionPipeline.from_pretrained(tmpdirname) - def test_ldm_uncond(self): - unet = self.dummy_uncond_unet - scheduler = DDIMScheduler() - vae = self.dummy_vq_model + assert sd.config.requires_safety_checker is False + assert sd.config.safety_checker == (None, None) + assert sd.config.feature_extractor == (None, None) - ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler) - ldm.to(torch_device) - ldm.set_progress_bar_config(disable=None) + orig_sd.save_pretrained(tmpdirname) - # Warmup pass when using mps (see #372) - if torch_device == "mps": - generator = torch.manual_seed(0) - _ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images + # Test that loading without any directory works + shutil.rmtree(os.path.join(tmpdirname, "safety_checker")) + with open(os.path.join(tmpdirname, sd.config_name)) as f: + config = json.load(f) + config["safety_checker"] = [None, None] + with open(os.path.join(tmpdirname, sd.config_name), "w") as f: + json.dump(config, f) - generator = torch.manual_seed(0) - image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images + sd = StableDiffusionPipeline.from_pretrained(tmpdirname, requires_safety_checker=False) + sd.save_pretrained(tmpdirname) + sd = StableDiffusionPipeline.from_pretrained(tmpdirname) - generator = torch.manual_seed(0) - image_from_tuple = ldm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0] + assert sd.config.requires_safety_checker is False + assert sd.config.safety_checker == (None, None) + assert sd.config.feature_extractor == (None, None) - image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + # Test that loading from deleted model index works + with open(os.path.join(tmpdirname, sd.config_name)) as f: + config = json.load(f) + del config["safety_checker"] + del config["feature_extractor"] + with open(os.path.join(tmpdirname, sd.config_name), "w") as f: + json.dump(config, f) - assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + sd = StableDiffusionPipeline.from_pretrained(tmpdirname) - def test_karras_ve_pipeline(self): - unet = self.dummy_uncond_unet - scheduler = KarrasVeScheduler() + assert sd.config.requires_safety_checker is False + assert sd.config.safety_checker == (None, None) + assert sd.config.feature_extractor == (None, None) - pipe = KarrasVePipeline(unet=unet, scheduler=scheduler) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + with tempfile.TemporaryDirectory() as tmpdirname: + sd.save_pretrained(tmpdirname) - generator = torch.manual_seed(0) - image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images + # Test that partially loading works + sd = StableDiffusionPipeline.from_pretrained(tmpdirname, feature_extractor=self.dummy_extractor) - generator = torch.manual_seed(0) - image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0] + assert sd.config.requires_safety_checker is False + assert sd.config.safety_checker == (None, None) + assert sd.config.feature_extractor != (None, None) - image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + # Test that partially loading works + sd = StableDiffusionPipeline.from_pretrained( + tmpdirname, + feature_extractor=self.dummy_extractor, + safety_checker=unet, + requires_safety_checker=[True, True], + ) - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + assert sd.config.requires_safety_checker == [True, True] + assert sd.config.safety_checker != (None, None) + assert sd.config.feature_extractor != (None, None) - def test_stable_diffusion_img2img(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_cond_unet - scheduler = PNDMScheduler(skip_prk_steps=True) - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + with tempfile.TemporaryDirectory() as tmpdirname: + sd.save_pretrained(tmpdirname) + sd = StableDiffusionPipeline.from_pretrained(tmpdirname, feature_extractor=self.dummy_extractor) - init_image = self.dummy_image.to(device) + assert sd.config.requires_safety_checker == [True, True] + assert sd.config.safety_checker != (None, None) + assert sd.config.feature_extractor != (None, None) - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionImg2ImgPipeline( - unet=unet, - scheduler=scheduler, - vae=vae, - text_encoder=bert, - tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, - feature_extractor=self.dummy_extractor, - ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - prompt = "A painting of a squirrel eating a burger" - generator = torch.Generator(device=device).manual_seed(0) - output = sd_pipe( - [prompt], - generator=generator, - guidance_scale=6.0, - num_inference_steps=2, - output_type="np", - init_image=init_image, - ) +@slow +class PipelineSlowTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() - image = output.images + def test_smart_download(self): + model_id = "hf-internal-testing/unet-pipeline-dummy" + with tempfile.TemporaryDirectory() as tmpdirname: + _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True) + local_repo_name = "--".join(["models"] + model_id.split("/")) + snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots") + snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0]) - generator = torch.Generator(device=device).manual_seed(0) - image_from_tuple = sd_pipe( - [prompt], - generator=generator, - guidance_scale=6.0, - num_inference_steps=2, - output_type="np", - init_image=init_image, - return_dict=False, - )[0] + # inspect all downloaded files to make sure that everything is included + assert os.path.isfile(os.path.join(snapshot_dir, DiffusionPipeline.config_name)) + assert os.path.isfile(os.path.join(snapshot_dir, CONFIG_NAME)) + assert os.path.isfile(os.path.join(snapshot_dir, SCHEDULER_CONFIG_NAME)) + assert os.path.isfile(os.path.join(snapshot_dir, WEIGHTS_NAME)) + assert os.path.isfile(os.path.join(snapshot_dir, "scheduler", SCHEDULER_CONFIG_NAME)) + assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME)) + assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME)) + # let's make sure the super large numpy file: + # https://huggingface.co/hf-internal-testing/unet-pipeline-dummy/blob/main/big_array.npy + # is not downloaded, but all the expected ones + assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy")) - image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + def test_warning_unused_kwargs(self): + model_id = "hf-internal-testing/unet-pipeline-dummy" + logger = logging.get_logger("diffusers.pipeline_utils") + with tempfile.TemporaryDirectory() as tmpdirname: + with CaptureLogger(logger) as cap_logger: + DiffusionPipeline.from_pretrained( + model_id, + not_used=True, + cache_dir=tmpdirname, + force_download=True, + ) - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - - def test_stable_diffusion_img2img_negative_prompt(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_cond_unet - scheduler = PNDMScheduler(skip_prk_steps=True) - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - init_image = self.dummy_image.to(device) - - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionImg2ImgPipeline( - unet=unet, - scheduler=scheduler, - vae=vae, - text_encoder=bert, - tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, - feature_extractor=self.dummy_extractor, - ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - negative_prompt = "french fries" - generator = torch.Generator(device=device).manual_seed(0) - output = sd_pipe( - prompt, - negative_prompt=negative_prompt, - generator=generator, - guidance_scale=6.0, - num_inference_steps=2, - output_type="np", - init_image=init_image, - ) - image = output.images - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.4065, 0.3783, 0.4050, 0.5266, 0.4781, 0.4252, 0.4203, 0.4692, 0.4365]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - def test_stable_diffusion_img2img_multiple_init_images(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_cond_unet - scheduler = PNDMScheduler(skip_prk_steps=True) - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - init_image = self.dummy_image.to(device).repeat(2, 1, 1, 1) - - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionImg2ImgPipeline( - unet=unet, - scheduler=scheduler, - vae=vae, - text_encoder=bert, - tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, - feature_extractor=self.dummy_extractor, - ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = 2 * ["A painting of a squirrel eating a burger"] - generator = torch.Generator(device=device).manual_seed(0) - output = sd_pipe( - prompt, - generator=generator, - guidance_scale=6.0, - num_inference_steps=2, - output_type="np", - init_image=init_image, - ) - - image = output.images - - image_slice = image[-1, -3:, -3:, -1] - - assert image.shape == (2, 32, 32, 3) - expected_slice = np.array([0.5144, 0.4447, 0.4735, 0.6676, 0.5526, 0.5454, 0.645, 0.5149, 0.4689]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - def test_stable_diffusion_img2img_k_lms(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_cond_unet - scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") - - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - init_image = self.dummy_image.to(device) - - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionImg2ImgPipeline( - unet=unet, - scheduler=scheduler, - vae=vae, - text_encoder=bert, - tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, - feature_extractor=self.dummy_extractor, - ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - generator = torch.Generator(device=device).manual_seed(0) - output = sd_pipe( - [prompt], - generator=generator, - guidance_scale=6.0, - num_inference_steps=2, - output_type="np", - init_image=init_image, - ) - image = output.images - - generator = torch.Generator(device=device).manual_seed(0) - output = sd_pipe( - [prompt], - generator=generator, - guidance_scale=6.0, - num_inference_steps=2, - output_type="np", - init_image=init_image, - return_dict=False, - ) - image_from_tuple = output[0] - - image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - - def test_stable_diffusion_inpaint(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_cond_unet - scheduler = PNDMScheduler(skip_prk_steps=True) - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] - init_image = Image.fromarray(np.uint8(image)).convert("RGB") - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) - - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionInpaintPipeline( - unet=unet, - scheduler=scheduler, - vae=vae, - text_encoder=bert, - tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, - feature_extractor=self.dummy_extractor, - ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - generator = torch.Generator(device=device).manual_seed(0) - output = sd_pipe( - [prompt], - generator=generator, - guidance_scale=6.0, - num_inference_steps=2, - output_type="np", - init_image=init_image, - mask_image=mask_image, - ) - - image = output.images - - generator = torch.Generator(device=device).manual_seed(0) - image_from_tuple = sd_pipe( - [prompt], - generator=generator, - guidance_scale=6.0, - num_inference_steps=2, - output_type="np", - init_image=init_image, - mask_image=mask_image, - return_dict=False, - )[0] - - image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - - def test_stable_diffusion_inpaint_negative_prompt(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_cond_unet - scheduler = PNDMScheduler(skip_prk_steps=True) - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] - init_image = Image.fromarray(np.uint8(image)).convert("RGB") - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) - - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionInpaintPipeline( - unet=unet, - scheduler=scheduler, - vae=vae, - text_encoder=bert, - tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, - feature_extractor=self.dummy_extractor, - ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - negative_prompt = "french fries" - generator = torch.Generator(device=device).manual_seed(0) - output = sd_pipe( - prompt, - negative_prompt=negative_prompt, - generator=generator, - guidance_scale=6.0, - num_inference_steps=2, - output_type="np", - init_image=init_image, - mask_image=mask_image, - ) - - image = output.images - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.4765, 0.5339, 0.4541, 0.6240, 0.5439, 0.4055, 0.5503, 0.5891, 0.5150]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - def test_stable_diffusion_num_images_per_prompt(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - unet = self.dummy_cond_unet - scheduler = PNDMScheduler(skip_prk_steps=True) - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionPipeline( - unet=unet, - scheduler=scheduler, - vae=vae, - text_encoder=bert, - tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, - feature_extractor=self.dummy_extractor, - ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - - # test num_images_per_prompt=1 (default) - images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images - - assert images.shape == (1, 128, 128, 3) - - # test num_images_per_prompt=1 (default) for batch of prompts - batch_size = 2 - images = sd_pipe([prompt] * batch_size, num_inference_steps=2, output_type="np").images - - assert images.shape == (batch_size, 128, 128, 3) - - # test num_images_per_prompt for single prompt - num_images_per_prompt = 2 - images = sd_pipe( - prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt - ).images - - assert images.shape == (num_images_per_prompt, 128, 128, 3) - - # test num_images_per_prompt for batch of prompts - batch_size = 2 - images = sd_pipe( - [prompt] * batch_size, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt - ).images - - assert images.shape == (batch_size * num_images_per_prompt, 128, 128, 3) - - def test_stable_diffusion_img2img_num_images_per_prompt(self): - device = "cpu" - unet = self.dummy_cond_unet - scheduler = PNDMScheduler(skip_prk_steps=True) - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - init_image = self.dummy_image.to(device) - - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionImg2ImgPipeline( - unet=unet, - scheduler=scheduler, - vae=vae, - text_encoder=bert, - tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, - feature_extractor=self.dummy_extractor, - ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - - # test num_images_per_prompt=1 (default) - images = sd_pipe( - prompt, - num_inference_steps=2, - output_type="np", - init_image=init_image, - ).images - - assert images.shape == (1, 32, 32, 3) - - # test num_images_per_prompt=1 (default) for batch of prompts - batch_size = 2 - images = sd_pipe( - [prompt] * batch_size, - num_inference_steps=2, - output_type="np", - init_image=init_image, - ).images - - assert images.shape == (batch_size, 32, 32, 3) - - # test num_images_per_prompt for single prompt - num_images_per_prompt = 2 - images = sd_pipe( - prompt, - num_inference_steps=2, - output_type="np", - init_image=init_image, - num_images_per_prompt=num_images_per_prompt, - ).images - - assert images.shape == (num_images_per_prompt, 32, 32, 3) - - # test num_images_per_prompt for batch of prompts - batch_size = 2 - images = sd_pipe( - [prompt] * batch_size, - num_inference_steps=2, - output_type="np", - init_image=init_image, - num_images_per_prompt=num_images_per_prompt, - ).images - - assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3) - - def test_stable_diffusion_inpaint_num_images_per_prompt(self): - device = "cpu" - unet = self.dummy_cond_unet - scheduler = PNDMScheduler(skip_prk_steps=True) - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] - init_image = Image.fromarray(np.uint8(image)).convert("RGB") - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) - - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionInpaintPipeline( - unet=unet, - scheduler=scheduler, - vae=vae, - text_encoder=bert, - tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, - feature_extractor=self.dummy_extractor, - ) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - - # test num_images_per_prompt=1 (default) - images = sd_pipe( - prompt, - num_inference_steps=2, - output_type="np", - init_image=init_image, - mask_image=mask_image, - ).images - - assert images.shape == (1, 32, 32, 3) - - # test num_images_per_prompt=1 (default) for batch of prompts - batch_size = 2 - images = sd_pipe( - [prompt] * batch_size, - num_inference_steps=2, - output_type="np", - init_image=init_image, - mask_image=mask_image, - ).images - - assert images.shape == (batch_size, 32, 32, 3) - - # test num_images_per_prompt for single prompt - num_images_per_prompt = 2 - images = sd_pipe( - prompt, - num_inference_steps=2, - output_type="np", - init_image=init_image, - mask_image=mask_image, - num_images_per_prompt=num_images_per_prompt, - ).images - - assert images.shape == (num_images_per_prompt, 32, 32, 3) - - # test num_images_per_prompt for batch of prompts - batch_size = 2 - images = sd_pipe( - [prompt] * batch_size, - num_inference_steps=2, - output_type="np", - init_image=init_image, - mask_image=mask_image, - num_images_per_prompt=num_images_per_prompt, - ).images - - assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3) - - @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") - def test_stable_diffusion_fp16(self): - """Test that stable diffusion works with fp16""" - unet = self.dummy_cond_unet - scheduler = PNDMScheduler(skip_prk_steps=True) - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - # put models in fp16 - unet = unet.half() - vae = vae.half() - bert = bert.half() - - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionPipeline( - unet=unet, - scheduler=scheduler, - vae=vae, - text_encoder=bert, - tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, - feature_extractor=self.dummy_extractor, - ) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - generator = torch.Generator(device=torch_device).manual_seed(0) - image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images - - assert image.shape == (1, 128, 128, 3) - - @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") - def test_stable_diffusion_img2img_fp16(self): - """Test that stable diffusion img2img works with fp16""" - unet = self.dummy_cond_unet - scheduler = PNDMScheduler(skip_prk_steps=True) - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - init_image = self.dummy_image.to(torch_device) - - # put models in fp16 - unet = unet.half() - vae = vae.half() - bert = bert.half() - - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionImg2ImgPipeline( - unet=unet, - scheduler=scheduler, - vae=vae, - text_encoder=bert, - tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, - feature_extractor=self.dummy_extractor, - ) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - generator = torch.Generator(device=torch_device).manual_seed(0) - image = sd_pipe( - [prompt], - generator=generator, - num_inference_steps=2, - output_type="np", - init_image=init_image, - ).images - - assert image.shape == (1, 32, 32, 3) - - @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") - def test_stable_diffusion_inpaint_fp16(self): - """Test that stable diffusion inpaint works with fp16""" - unet = self.dummy_cond_unet - scheduler = PNDMScheduler(skip_prk_steps=True) - vae = self.dummy_vae - bert = self.dummy_text_encoder - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] - init_image = Image.fromarray(np.uint8(image)).convert("RGB") - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) - - # put models in fp16 - unet = unet.half() - vae = vae.half() - bert = bert.half() - - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionInpaintPipeline( - unet=unet, - scheduler=scheduler, - vae=vae, - text_encoder=bert, - tokenizer=tokenizer, - safety_checker=self.dummy_safety_checker, - feature_extractor=self.dummy_extractor, - ) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - generator = torch.Generator(device=torch_device).manual_seed(0) - image = sd_pipe( - [prompt], - generator=generator, - num_inference_steps=2, - output_type="np", - init_image=init_image, - mask_image=mask_image, - ).images - - assert image.shape == (1, 32, 32, 3) - - -class PipelineTesterMixin(unittest.TestCase): - def tearDown(self): - # clean up the VRAM after each test - super().tearDown() - gc.collect() - torch.cuda.empty_cache() - - def test_smart_download(self): - model_id = "hf-internal-testing/unet-pipeline-dummy" - with tempfile.TemporaryDirectory() as tmpdirname: - _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True) - local_repo_name = "--".join(["models"] + model_id.split("/")) - snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots") - snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0]) - - # inspect all downloaded files to make sure that everything is included - assert os.path.isfile(os.path.join(snapshot_dir, DiffusionPipeline.config_name)) - assert os.path.isfile(os.path.join(snapshot_dir, CONFIG_NAME)) - assert os.path.isfile(os.path.join(snapshot_dir, SCHEDULER_CONFIG_NAME)) - assert os.path.isfile(os.path.join(snapshot_dir, WEIGHTS_NAME)) - assert os.path.isfile(os.path.join(snapshot_dir, "scheduler", SCHEDULER_CONFIG_NAME)) - assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME)) - assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME)) - # let's make sure the super large numpy file: - # https://huggingface.co/hf-internal-testing/unet-pipeline-dummy/blob/main/big_array.npy - # is not downloaded, but all the expected ones - assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy")) - - @property - def dummy_safety_checker(self): - def check(images, *args, **kwargs): - return images, [False] * len(images) - - return check + assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n" def test_from_pretrained_save_pretrained(self): # 1. Load models @@ -1372,7 +660,7 @@ def test_from_pretrained_save_pretrained(self): new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) new_ddpm.to(torch_device) - generator = torch.manual_seed(0) + generator = torch.Generator(device=torch_device).manual_seed(0) image = ddpm(generator=generator, output_type="numpy").images generator = generator.manual_seed(0) @@ -1380,301 +668,75 @@ def test_from_pretrained_save_pretrained(self): assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" - @slow - def test_from_pretrained_hub(self): - model_path = "google/ddpm-cifar10-32" - - scheduler = DDPMScheduler(num_train_timesteps=10) - - ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler) - ddpm.to(torch_device) - ddpm.set_progress_bar_config(disable=None) - ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler) - ddpm_from_hub.to(torch_device) - ddpm_from_hub.set_progress_bar_config(disable=None) - - generator = torch.manual_seed(0) - image = ddpm(generator=generator, output_type="numpy").images - - generator = generator.manual_seed(0) - new_image = ddpm_from_hub(generator=generator, output_type="numpy").images - - assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" - - @slow - def test_from_pretrained_hub_pass_model(self): - model_path = "google/ddpm-cifar10-32" - - scheduler = DDPMScheduler(num_train_timesteps=10) - - # pass unet into DiffusionPipeline - unet = UNet2DModel.from_pretrained(model_path) - ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler) - ddpm_from_hub_custom_model.to(torch_device) - ddpm_from_hub_custom_model.set_progress_bar_config(disable=None) - - ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler) - ddpm_from_hub.to(torch_device) - ddpm_from_hub_custom_model.set_progress_bar_config(disable=None) - - generator = torch.manual_seed(0) - image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images - - generator = generator.manual_seed(0) - new_image = ddpm_from_hub(generator=generator, output_type="numpy").images - - assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" - - @slow - def test_output_format(self): - model_path = "google/ddpm-cifar10-32" - - pipe = DDIMPipeline.from_pretrained(model_path) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - generator = torch.manual_seed(0) - images = pipe(generator=generator, output_type="numpy").images - assert images.shape == (1, 32, 32, 3) - assert isinstance(images, np.ndarray) - - images = pipe(generator=generator, output_type="pil").images - assert isinstance(images, list) - assert len(images) == 1 - assert isinstance(images[0], PIL.Image.Image) - - # use PIL by default - images = pipe(generator=generator).images - assert isinstance(images, list) - assert isinstance(images[0], PIL.Image.Image) - - @slow - def test_ddpm_cifar10(self): - model_id = "google/ddpm-cifar10-32" - - unet = UNet2DModel.from_pretrained(model_id) - scheduler = DDPMScheduler.from_config(model_id) - - ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) - ddpm.to(torch_device) - ddpm.set_progress_bar_config(disable=None) - - generator = torch.manual_seed(0) - image = ddpm(generator=generator, output_type="numpy").images - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @slow - def test_ddim_lsun(self): - model_id = "google/ddpm-ema-bedroom-256" - - unet = UNet2DModel.from_pretrained(model_id) - scheduler = DDIMScheduler.from_config(model_id) - - ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) - ddpm.to(torch_device) - ddpm.set_progress_bar_config(disable=None) - - generator = torch.manual_seed(0) - image = ddpm(generator=generator, output_type="numpy").images - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 256, 256, 3) - expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @slow - def test_ddim_cifar10(self): - model_id = "google/ddpm-cifar10-32" - - unet = UNet2DModel.from_pretrained(model_id) - scheduler = DDIMScheduler() - - ddim = DDIMPipeline(unet=unet, scheduler=scheduler) - ddim.to(torch_device) - ddim.set_progress_bar_config(disable=None) - - generator = torch.manual_seed(0) - image = ddim(generator=generator, eta=0.0, output_type="numpy").images - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @slow - def test_pndm_cifar10(self): - model_id = "google/ddpm-cifar10-32" - - unet = UNet2DModel.from_pretrained(model_id) - scheduler = PNDMScheduler() - - pndm = PNDMPipeline(unet=unet, scheduler=scheduler) - pndm.to(torch_device) - pndm.set_progress_bar_config(disable=None) - generator = torch.manual_seed(0) - image = pndm(generator=generator, output_type="numpy").images - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @slow - def test_ldm_text2img(self): - ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256") - ldm.to(torch_device) - ldm.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - generator = torch.manual_seed(0) - image = ldm( - [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy" - ).images - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 256, 256, 3) - expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @slow - def test_ldm_text2img_fast(self): - ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256") - ldm.to(torch_device) - ldm.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - generator = torch.manual_seed(0) - image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 256, 256, 3) - expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion(self): - # make sure here that pndm scheduler skips prk - sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1") - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - prompt = "A painting of a squirrel eating a burger" - generator = torch.Generator(device=torch_device).manual_seed(0) - with torch.autocast("cuda"): - output = sd_pipe( - [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np" - ) - - image = output.images - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 512, 512, 3) - expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion_fast_ddim(self): - sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1") - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - scheduler = DDIMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False, - ) - sd_pipe.scheduler = scheduler - - prompt = "A painting of a squirrel eating a burger" - generator = torch.Generator(device=torch_device).manual_seed(0) - - with torch.autocast("cuda"): - output = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy") - image = output.images - - image_slice = image[0, -3:, -3:, -1] + def test_from_pretrained_hub(self): + model_path = "google/ddpm-cifar10-32" - assert image.shape == (1, 512, 512, 3) - expected_slice = np.array([0.9326, 0.923, 0.951, 0.9365, 0.9214, 0.951, 0.9365, 0.9414, 0.918]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + scheduler = DDPMScheduler(num_train_timesteps=10) - @slow - def test_score_sde_ve_pipeline(self): - model_id = "google/ncsnpp-church-256" - model = UNet2DModel.from_pretrained(model_id) + ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler) + ddpm = ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) - scheduler = ScoreSdeVeScheduler.from_config(model_id) + ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler) + ddpm_from_hub = ddpm_from_hub.to(torch_device) + ddpm_from_hub.set_progress_bar_config(disable=None) - sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler) - sde_ve.to(torch_device) - sde_ve.set_progress_bar_config(disable=None) + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ddpm(generator=generator, output_type="numpy").images - generator = torch.manual_seed(0) - image = sde_ve(num_inference_steps=10, output_type="numpy", generator=generator).images + generator = generator.manual_seed(0) + new_image = ddpm_from_hub(generator=generator, output_type="numpy").images - image_slice = image[0, -3:, -3:, -1] + assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" - assert image.shape == (1, 256, 256, 3) + def test_from_pretrained_hub_pass_model(self): + model_path = "google/ddpm-cifar10-32" - expected_slice = np.array([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + scheduler = DDPMScheduler(num_train_timesteps=10) - @slow - def test_ldm_uncond(self): - ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256") - ldm.to(torch_device) - ldm.set_progress_bar_config(disable=None) + # pass unet into DiffusionPipeline + unet = UNet2DModel.from_pretrained(model_path) + ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler) + ddpm_from_hub_custom_model = ddpm_from_hub_custom_model.to(torch_device) + ddpm_from_hub_custom_model.set_progress_bar_config(disable=None) - generator = torch.manual_seed(0) - image = ldm(generator=generator, num_inference_steps=5, output_type="numpy").images + ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler) + ddpm_from_hub = ddpm_from_hub.to(torch_device) + ddpm_from_hub_custom_model.set_progress_bar_config(disable=None) - image_slice = image[0, -3:, -3:, -1] + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images - assert image.shape == (1, 256, 256, 3) - expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + generator = generator.manual_seed(0) + new_image = ddpm_from_hub(generator=generator, output_type="numpy").images - @slow - def test_ddpm_ddim_equality(self): - model_id = "google/ddpm-cifar10-32" + assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" - unet = UNet2DModel.from_pretrained(model_id) - ddpm_scheduler = DDPMScheduler() - ddim_scheduler = DDIMScheduler() + def test_output_format(self): + model_path = "google/ddpm-cifar10-32" - ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler) - ddpm.to(torch_device) - ddpm.set_progress_bar_config(disable=None) - ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler) - ddim.to(torch_device) - ddim.set_progress_bar_config(disable=None) + scheduler = DDIMScheduler.from_pretrained(model_path) + pipe = DDIMPipeline.from_pretrained(model_path, scheduler=scheduler) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) - generator = torch.manual_seed(0) - ddpm_image = ddpm(generator=generator, output_type="numpy").images + generator = torch.Generator(device=torch_device).manual_seed(0) + images = pipe(generator=generator, output_type="numpy").images + assert images.shape == (1, 32, 32, 3) + assert isinstance(images, np.ndarray) - generator = torch.manual_seed(0) - ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy").images + images = pipe(generator=generator, output_type="pil", num_inference_steps=4).images + assert isinstance(images, list) + assert len(images) == 1 + assert isinstance(images[0], PIL.Image.Image) - # the values aren't exactly equal, but the images look the same visually - assert np.abs(ddpm_image - ddim_image).max() < 1e-1 + # use PIL by default + images = pipe(generator=generator, num_inference_steps=4).images + assert isinstance(images, list) + assert isinstance(images[0], PIL.Image.Image) - @unittest.skip("(Anton) The test is failing for large batch sizes, needs investigation") def test_ddpm_ddim_equality_batched(self): + seed = 0 model_id = "google/ddpm-cifar10-32" unet = UNet2DModel.from_pretrained(model_id) @@ -1689,589 +751,18 @@ def test_ddpm_ddim_equality_batched(self): ddim.to(torch_device) ddim.set_progress_bar_config(disable=None) - generator = torch.manual_seed(0) - ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images + generator = torch.Generator(device=torch_device).manual_seed(seed) + ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy").images - generator = torch.manual_seed(0) + generator = torch.Generator(device=torch_device).manual_seed(seed) ddim_images = ddim( - batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy" + batch_size=2, + generator=generator, + num_inference_steps=1000, + eta=1.0, + output_type="numpy", + use_clipped_model_output=True, # Need this to make DDIM match DDPM ).images # the values aren't exactly equal, but the images look the same visually assert np.abs(ddpm_images - ddim_images).max() < 1e-1 - - @slow - def test_karras_ve_pipeline(self): - model_id = "google/ncsnpp-celebahq-256" - model = UNet2DModel.from_pretrained(model_id) - scheduler = KarrasVeScheduler() - - pipe = KarrasVePipeline(unet=model, scheduler=scheduler) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - generator = torch.manual_seed(0) - image = pipe(num_inference_steps=20, generator=generator, output_type="numpy").images - - image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 256, 256, 3) - expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_lms_stable_diffusion_pipeline(self): - model_id = "CompVis/stable-diffusion-v1-1" - pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device) - pipe.set_progress_bar_config(disable=None) - scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") - pipe.scheduler = scheduler - - prompt = "a photograph of an astronaut riding a horse" - generator = torch.Generator(device=torch_device).manual_seed(0) - image = pipe( - [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" - ).images - - image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 512, 512, 3) - expected_slice = np.array([0.9077, 0.9254, 0.9181, 0.9227, 0.9213, 0.9367, 0.9399, 0.9406, 0.9024]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion_memory_chunking(self): - torch.cuda.reset_peak_memory_stats() - model_id = "CompVis/stable-diffusion-v1-4" - pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16).to( - torch_device - ) - pipe.set_progress_bar_config(disable=None) - - prompt = "a photograph of an astronaut riding a horse" - - # make attention efficient - pipe.enable_attention_slicing() - generator = torch.Generator(device=torch_device).manual_seed(0) - with torch.autocast(torch_device): - output_chunked = pipe( - [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" - ) - image_chunked = output_chunked.images - - mem_bytes = torch.cuda.max_memory_allocated() - torch.cuda.reset_peak_memory_stats() - # make sure that less than 3.75 GB is allocated - assert mem_bytes < 3.75 * 10**9 - - # disable chunking - pipe.disable_attention_slicing() - generator = torch.Generator(device=torch_device).manual_seed(0) - with torch.autocast(torch_device): - output = pipe( - [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" - ) - image = output.images - - # make sure that more than 3.75 GB is allocated - mem_bytes = torch.cuda.max_memory_allocated() - assert mem_bytes > 3.75 * 10**9 - assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3 - - @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion_text2img_pipeline_fp16(self): - torch.cuda.reset_peak_memory_stats() - model_id = "CompVis/stable-diffusion-v1-4" - pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16).to( - torch_device - ) - pipe.set_progress_bar_config(disable=None) - - prompt = "a photograph of an astronaut riding a horse" - - generator = torch.Generator(device=torch_device).manual_seed(0) - output_chunked = pipe( - [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" - ) - image_chunked = output_chunked.images - - generator = torch.Generator(device=torch_device).manual_seed(0) - with torch.autocast(torch_device): - output = pipe( - [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" - ) - image = output.images - - # Make sure results are close enough - diff = np.abs(image_chunked.flatten() - image.flatten()) - # They ARE different since ops are not run always at the same precision - # however, they should be extremely close. - assert diff.mean() < 2e-2 - - @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion_text2img_pipeline(self): - expected_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/text2img/astronaut_riding_a_horse.png" - ) - expected_image = np.array(expected_image, dtype=np.float32) / 255.0 - - model_id = "CompVis/stable-diffusion-v1-4" - pipe = StableDiffusionPipeline.from_pretrained( - model_id, - safety_checker=self.dummy_safety_checker, - ) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.enable_attention_slicing() - - prompt = "astronaut riding a horse" - - generator = torch.Generator(device=torch_device).manual_seed(0) - output = pipe(prompt=prompt, strength=0.75, guidance_scale=7.5, generator=generator, output_type="np") - image = output.images[0] - - assert image.shape == (512, 512, 3) - assert np.abs(expected_image - image).max() < 1e-2 - - @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion_img2img_pipeline(self): - init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/img2img/sketch-mountains-input.jpg" - ) - expected_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/img2img/fantasy_landscape.png" - ) - init_image = init_image.resize((768, 512)) - expected_image = np.array(expected_image, dtype=np.float32) / 255.0 - - model_id = "CompVis/stable-diffusion-v1-4" - pipe = StableDiffusionImg2ImgPipeline.from_pretrained( - model_id, - safety_checker=self.dummy_safety_checker, - ) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.enable_attention_slicing() - - prompt = "A fantasy landscape, trending on artstation" - - generator = torch.Generator(device=torch_device).manual_seed(0) - output = pipe( - prompt=prompt, - init_image=init_image, - strength=0.75, - guidance_scale=7.5, - generator=generator, - output_type="np", - ) - image = output.images[0] - - assert image.shape == (512, 768, 3) - # img2img is flaky across GPUs even in fp32, so using MAE here - assert np.abs(expected_image - image).mean() < 1e-2 - - @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion_img2img_pipeline_k_lms(self): - init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/img2img/sketch-mountains-input.jpg" - ) - expected_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/img2img/fantasy_landscape_k_lms.png" - ) - init_image = init_image.resize((768, 512)) - expected_image = np.array(expected_image, dtype=np.float32) / 255.0 - - lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") - - model_id = "CompVis/stable-diffusion-v1-4" - pipe = StableDiffusionImg2ImgPipeline.from_pretrained( - model_id, - scheduler=lms, - safety_checker=self.dummy_safety_checker, - ) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.enable_attention_slicing() - - prompt = "A fantasy landscape, trending on artstation" - - generator = torch.Generator(device=torch_device).manual_seed(0) - output = pipe( - prompt=prompt, - init_image=init_image, - strength=0.75, - guidance_scale=7.5, - generator=generator, - output_type="np", - ) - image = output.images[0] - - assert image.shape == (512, 768, 3) - # img2img is flaky across GPUs even in fp32, so using MAE here - assert np.abs(expected_image - image).mean() < 1e-2 - - @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion_inpaint_pipeline(self): - init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo.png" - ) - mask_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" - ) - expected_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/red_cat_sitting_on_a_park_bench.png" - ) - expected_image = np.array(expected_image, dtype=np.float32) / 255.0 - - model_id = "CompVis/stable-diffusion-v1-4" - pipe = StableDiffusionInpaintPipeline.from_pretrained( - model_id, - safety_checker=self.dummy_safety_checker, - ) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.enable_attention_slicing() - - prompt = "A red cat sitting on a park bench" - - generator = torch.Generator(device=torch_device).manual_seed(0) - output = pipe( - prompt=prompt, - init_image=init_image, - mask_image=mask_image, - strength=0.75, - guidance_scale=7.5, - generator=generator, - output_type="np", - ) - image = output.images[0] - - assert image.shape == (512, 512, 3) - assert np.abs(expected_image - image).max() < 1e-2 - - @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion_inpaint_pipeline_k_lms(self): - init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo.png" - ) - mask_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" - ) - expected_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/red_cat_sitting_on_a_park_bench_k_lms.png" - ) - expected_image = np.array(expected_image, dtype=np.float32) / 255.0 - - lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") - - model_id = "CompVis/stable-diffusion-v1-4" - pipe = StableDiffusionInpaintPipeline.from_pretrained( - model_id, - scheduler=lms, - safety_checker=self.dummy_safety_checker, - ) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.enable_attention_slicing() - - prompt = "A red cat sitting on a park bench" - - generator = torch.Generator(device=torch_device).manual_seed(0) - output = pipe( - prompt=prompt, - init_image=init_image, - mask_image=mask_image, - strength=0.75, - guidance_scale=7.5, - generator=generator, - output_type="np", - ) - image = output.images[0] - - assert image.shape == (512, 512, 3) - assert np.abs(expected_image - image).max() < 1e-2 - - @slow - def test_stable_diffusion_onnx(self): - sd_pipe = StableDiffusionOnnxPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider" - ) - - prompt = "A painting of a squirrel eating a burger" - np.random.seed(0) - output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=5, output_type="np") - image = output.images - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 512, 512, 3) - expected_slice = np.array([0.3602, 0.3688, 0.3652, 0.3895, 0.3782, 0.3747, 0.3927, 0.4241, 0.4327]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - - @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion_text2img_intermediate_state(self): - number_of_steps = 0 - - def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: - test_callback_fn.has_been_called = True - nonlocal number_of_steps - number_of_steps += 1 - if step == 0: - latents = latents.detach().cpu().numpy() - assert latents.shape == (1, 4, 64, 64) - latents_slice = latents[0, -3:, -3:, -1] - expected_slice = np.array( - [1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506] - ) - assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 - elif step == 50: - latents = latents.detach().cpu().numpy() - assert latents.shape == (1, 4, 64, 64) - latents_slice = latents[0, -3:, -3:, -1] - expected_slice = np.array( - [1.1078, 1.5803, 0.2773, -0.0589, -1.7928, -0.3665, -0.4695, -1.0727, -1.1601] - ) - assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2 - - test_callback_fn.has_been_called = False - - pipe = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16 - ) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.enable_attention_slicing() - - prompt = "Andromeda galaxy in a bottle" - - generator = torch.Generator(device=torch_device).manual_seed(0) - with torch.autocast(torch_device): - pipe( - prompt=prompt, - num_inference_steps=50, - guidance_scale=7.5, - generator=generator, - callback=test_callback_fn, - callback_steps=1, - ) - assert test_callback_fn.has_been_called - assert number_of_steps == 51 - - @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion_img2img_intermediate_state(self): - number_of_steps = 0 - - def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: - test_callback_fn.has_been_called = True - nonlocal number_of_steps - number_of_steps += 1 - if step == 0: - latents = latents.detach().cpu().numpy() - assert latents.shape == (1, 4, 64, 96) - latents_slice = latents[0, -3:, -3:, -1] - expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530]) - assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 - elif step == 37: - latents = latents.detach().cpu().numpy() - assert latents.shape == (1, 4, 64, 96) - latents_slice = latents[0, -3:, -3:, -1] - expected_slice = np.array([0.7071, 0.7831, 0.8300, 1.8140, 1.7840, 1.9402, 1.3651, 1.6590, 1.2828]) - assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2 - - test_callback_fn.has_been_called = False - - init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/img2img/sketch-mountains-input.jpg" - ) - init_image = init_image.resize((768, 512)) - - pipe = StableDiffusionImg2ImgPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16 - ) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.enable_attention_slicing() - - prompt = "A fantasy landscape, trending on artstation" - - generator = torch.Generator(device=torch_device).manual_seed(0) - with torch.autocast(torch_device): - pipe( - prompt=prompt, - init_image=init_image, - strength=0.75, - num_inference_steps=50, - guidance_scale=7.5, - generator=generator, - callback=test_callback_fn, - callback_steps=1, - ) - assert test_callback_fn.has_been_called - assert number_of_steps == 38 - - @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion_inpaint_intermediate_state(self): - number_of_steps = 0 - - def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: - test_callback_fn.has_been_called = True - nonlocal number_of_steps - number_of_steps += 1 - if step == 0: - latents = latents.detach().cpu().numpy() - assert latents.shape == (1, 4, 64, 64) - latents_slice = latents[0, -3:, -3:, -1] - expected_slice = np.array( - [-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246] - ) - assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 - elif step == 37: - latents = latents.detach().cpu().numpy() - assert latents.shape == (1, 4, 64, 64) - latents_slice = latents[0, -3:, -3:, -1] - expected_slice = np.array([0.4781, 1.1572, 0.6258, 0.2291, 0.2554, -0.1443, 0.7085, -0.1598, -0.5659]) - assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 - - test_callback_fn.has_been_called = False - - init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo.png" - ) - mask_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" - ) - - pipe = StableDiffusionInpaintPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16 - ) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.enable_attention_slicing() - - prompt = "A red cat sitting on a park bench" - - generator = torch.Generator(device=torch_device).manual_seed(0) - with torch.autocast(torch_device): - pipe( - prompt=prompt, - init_image=init_image, - mask_image=mask_image, - strength=0.75, - num_inference_steps=50, - guidance_scale=7.5, - generator=generator, - callback=test_callback_fn, - callback_steps=1, - ) - assert test_callback_fn.has_been_called - assert number_of_steps == 38 - - @slow - def test_stable_diffusion_onnx_intermediate_state(self): - number_of_steps = 0 - - def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: - test_callback_fn.has_been_called = True - nonlocal number_of_steps - number_of_steps += 1 - if step == 0: - assert latents.shape == (1, 4, 64, 64) - latents_slice = latents[0, -3:, -3:, -1] - expected_slice = np.array( - [-0.5950, -0.3039, -1.1672, 0.1594, -1.1572, 0.6719, -1.9712, -0.0403, 0.9592] - ) - assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 - elif step == 5: - assert latents.shape == (1, 4, 64, 64) - latents_slice = latents[0, -3:, -3:, -1] - expected_slice = np.array( - [-0.4776, -0.0119, -0.8519, -0.0275, -0.9764, 0.9820, -0.3843, 0.3788, 1.2264] - ) - assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 - - test_callback_fn.has_been_called = False - - pipe = StableDiffusionOnnxPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider" - ) - pipe.set_progress_bar_config(disable=None) - - prompt = "Andromeda galaxy in a bottle" - - np.random.seed(0) - pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1) - assert test_callback_fn.has_been_called - assert number_of_steps == 6 - - @slow - @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion_accelerate_load_works(self): - if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"): - return - - if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"): - return - - model_id = "CompVis/stable-diffusion-v1-4" - _ = StableDiffusionPipeline.from_pretrained( - model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto" - ).to(torch_device) - - @slow - @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU") - def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self): - if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"): - return - - if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"): - return - - pipeline_id = "CompVis/stable-diffusion-v1-4" - - torch.cuda.empty_cache() - gc.collect() - - tracemalloc.start() - pipeline_normal_load = StableDiffusionPipeline.from_pretrained( - pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True - ) - pipeline_normal_load.to(torch_device) - _, peak_normal = tracemalloc.get_traced_memory() - tracemalloc.stop() - - del pipeline_normal_load - torch.cuda.empty_cache() - gc.collect() - - tracemalloc.start() - _ = StableDiffusionPipeline.from_pretrained( - pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto" - ) - _, peak_accelerate = tracemalloc.get_traced_memory() - - tracemalloc.stop() - - assert peak_accelerate < peak_normal diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py new file mode 100644 index 000000000000..bf99925f1c4f --- /dev/null +++ b/tests/test_pipelines_common.py @@ -0,0 +1,12 @@ +from diffusers.utils.testing_utils import require_torch + + +@require_torch +class PipelineTesterMixin: + """ + This mixin is designed to be used with unittest.TestCase classes. + It provides a set of common tests for each PyTorch pipeline, e.g. saving and loading the pipeline, + equivalence of dict and tuple outputs, etc. + """ + + pass diff --git a/tests/test_pipelines_flax.py b/tests/test_pipelines_flax.py index 9256944815c7..9b9dcddd6060 100644 --- a/tests/test_pipelines_flax.py +++ b/tests/test_pipelines_flax.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import tempfile import unittest import numpy as np @@ -24,14 +26,31 @@ if is_flax_available(): import jax import jax.numpy as jnp - from diffusers import FlaxDDIMScheduler, FlaxStableDiffusionPipeline + from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline from flax.jax_utils import replicate from flax.training.common_utils import shard from jax import pmap @require_flax +class DownloadTests(unittest.TestCase): + def test_download_only_pytorch(self): + with tempfile.TemporaryDirectory() as tmpdirname: + # pipeline has Flax weights + _ = FlaxDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname + ) + + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))] + files = [item for sublist in all_root_files for item in sublist] + + # None of the downloaded files should be a PyTorch file even if we have some here: + # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin + assert not any(f.endswith(".bin") for f in files) + + @slow +@require_flax class FlaxPipelineTests(unittest.TestCase): def test_dummy_all_tpus(self): pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( @@ -54,18 +73,19 @@ def test_dummy_all_tpus(self): # shard inputs and rng params = replicate(params) - prng_seed = jax.random.split(prng_seed, 8) + prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images - assert images.shape == (8, 1, 64, 64, 3) - assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.151474)) < 1e-3 - assert np.abs((np.abs(images, dtype=np.float32).sum() - 49947.875)) < 5e-1 + assert images.shape == (num_samples, 1, 64, 64, 3) + if jax.device_count() == 8: + assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 3.1111548) < 1e-3 + assert np.abs(np.abs(images, dtype=np.float32).sum() - 199746.95) < 5e-1 images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) - assert len(images_pil) == 8 + assert len(images_pil) == num_samples def test_stable_diffusion_v1_4(self): pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( @@ -88,14 +108,15 @@ def test_stable_diffusion_v1_4(self): # shard inputs and rng params = replicate(params) - prng_seed = jax.random.split(prng_seed, 8) + prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images - assert images.shape == (8, 1, 512, 512, 3) - assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-3 - assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 5e-1 + assert images.shape == (num_samples, 1, 512, 512, 3) + if jax.device_count() == 8: + assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-3 + assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 5e-1 def test_stable_diffusion_v1_4_bfloat_16(self): pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( @@ -118,14 +139,15 @@ def test_stable_diffusion_v1_4_bfloat_16(self): # shard inputs and rng params = replicate(params) - prng_seed = jax.random.split(prng_seed, 8) + prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images - assert images.shape == (8, 1, 512, 512, 3) - assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3 - assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1 + assert images.shape == (num_samples, 1, 512, 512, 3) + if jax.device_count() == 8: + assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3 + assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1 def test_stable_diffusion_v1_4_bfloat_16_with_safety(self): pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( @@ -146,14 +168,15 @@ def test_stable_diffusion_v1_4_bfloat_16_with_safety(self): # shard inputs and rng params = replicate(params) - prng_seed = jax.random.split(prng_seed, 8) + prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images - assert images.shape == (8, 1, 512, 512, 3) - assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3 - assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1 + assert images.shape == (num_samples, 1, 512, 512, 3) + if jax.device_count() == 8: + assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3 + assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1 def test_stable_diffusion_v1_4_bfloat_16_ddim(self): scheduler = FlaxDDIMScheduler( @@ -191,11 +214,12 @@ def test_stable_diffusion_v1_4_bfloat_16_ddim(self): # shard inputs and rng params = replicate(params) - prng_seed = jax.random.split(prng_seed, 8) + prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images - assert images.shape == (8, 1, 512, 512, 3) - assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 1e-3 - assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1 + assert images.shape == (num_samples, 1, 512, 512, 3) + if jax.device_count() == 8: + assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 1e-3 + assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1 diff --git a/tests/test_pipelines_onnx_common.py b/tests/test_pipelines_onnx_common.py new file mode 100644 index 000000000000..575ecd007531 --- /dev/null +++ b/tests/test_pipelines_onnx_common.py @@ -0,0 +1,12 @@ +from diffusers.utils.testing_utils import require_onnxruntime + + +@require_onnxruntime +class OnnxPipelineTesterMixin: + """ + This mixin is designed to be used with unittest.TestCase classes. + It provides a set of common tests for each ONNXRuntime pipeline, e.g. saving and loading the pipeline, + equivalence of dict and tuple outputs, etc. + """ + + pass diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index c3d4b9bc76f9..6a76581632ad 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -12,19 +12,203 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect +import json +import os import tempfile import unittest from typing import Dict, List, Tuple import numpy as np import torch - -from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler +import torch.nn.functional as F + +import diffusers +from diffusers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + IPNDMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + ScoreSdeVeScheduler, + VQDiffusionScheduler, + logging, +) +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import deprecate, torch_device +from diffusers.utils.testing_utils import CaptureLogger torch.backends.cuda.matmul.allow_tf32 = False +class SchedulerObject(SchedulerMixin, ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + e=[1, 3], + ): + pass + + +class SchedulerObject2(SchedulerMixin, ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + f=[1, 3], + ): + pass + + +class SchedulerObject3(SchedulerMixin, ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + e=[1, 3], + f=[1, 3], + ): + pass + + +class SchedulerBaseTests(unittest.TestCase): + def test_save_load_from_different_config(self): + obj = SchedulerObject() + + # mock add obj class to `diffusers` + setattr(diffusers, "SchedulerObject", SchedulerObject) + logger = logging.get_logger("diffusers.configuration_utils") + + with tempfile.TemporaryDirectory() as tmpdirname: + obj.save_config(tmpdirname) + with CaptureLogger(logger) as cap_logger_1: + config = SchedulerObject2.load_config(tmpdirname) + new_obj_1 = SchedulerObject2.from_config(config) + + # now save a config parameter that is not expected + with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f: + data = json.load(f) + data["unexpected"] = True + + with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f: + json.dump(data, f) + + with CaptureLogger(logger) as cap_logger_2: + config = SchedulerObject.load_config(tmpdirname) + new_obj_2 = SchedulerObject.from_config(config) + + with CaptureLogger(logger) as cap_logger_3: + config = SchedulerObject2.load_config(tmpdirname) + new_obj_3 = SchedulerObject2.from_config(config) + + assert new_obj_1.__class__ == SchedulerObject2 + assert new_obj_2.__class__ == SchedulerObject + assert new_obj_3.__class__ == SchedulerObject2 + + assert cap_logger_1.out == "" + assert ( + cap_logger_2.out + == "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and" + " will" + " be ignored. Please verify your config.json configuration file.\n" + ) + assert cap_logger_2.out.replace("SchedulerObject", "SchedulerObject2") == cap_logger_3.out + + def test_save_load_compatible_schedulers(self): + SchedulerObject2._compatibles = ["SchedulerObject"] + SchedulerObject._compatibles = ["SchedulerObject2"] + + obj = SchedulerObject() + + # mock add obj class to `diffusers` + setattr(diffusers, "SchedulerObject", SchedulerObject) + setattr(diffusers, "SchedulerObject2", SchedulerObject2) + logger = logging.get_logger("diffusers.configuration_utils") + + with tempfile.TemporaryDirectory() as tmpdirname: + obj.save_config(tmpdirname) + + # now save a config parameter that is expected by another class, but not origin class + with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f: + data = json.load(f) + data["f"] = [0, 0] + data["unexpected"] = True + + with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f: + json.dump(data, f) + + with CaptureLogger(logger) as cap_logger: + config = SchedulerObject.load_config(tmpdirname) + new_obj = SchedulerObject.from_config(config) + + assert new_obj.__class__ == SchedulerObject + + assert ( + cap_logger.out + == "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and" + " will" + " be ignored. Please verify your config.json configuration file.\n" + ) + + def test_save_load_from_different_config_comp_schedulers(self): + SchedulerObject3._compatibles = ["SchedulerObject", "SchedulerObject2"] + SchedulerObject2._compatibles = ["SchedulerObject", "SchedulerObject3"] + SchedulerObject._compatibles = ["SchedulerObject2", "SchedulerObject3"] + + obj = SchedulerObject() + + # mock add obj class to `diffusers` + setattr(diffusers, "SchedulerObject", SchedulerObject) + setattr(diffusers, "SchedulerObject2", SchedulerObject2) + setattr(diffusers, "SchedulerObject3", SchedulerObject3) + logger = logging.get_logger("diffusers.configuration_utils") + logger.setLevel(diffusers.logging.INFO) + + with tempfile.TemporaryDirectory() as tmpdirname: + obj.save_config(tmpdirname) + + with CaptureLogger(logger) as cap_logger_1: + config = SchedulerObject.load_config(tmpdirname) + new_obj_1 = SchedulerObject.from_config(config) + + with CaptureLogger(logger) as cap_logger_2: + config = SchedulerObject2.load_config(tmpdirname) + new_obj_2 = SchedulerObject2.from_config(config) + + with CaptureLogger(logger) as cap_logger_3: + config = SchedulerObject3.load_config(tmpdirname) + new_obj_3 = SchedulerObject3.from_config(config) + + assert new_obj_1.__class__ == SchedulerObject + assert new_obj_2.__class__ == SchedulerObject2 + assert new_obj_3.__class__ == SchedulerObject3 + + assert cap_logger_1.out == "" + assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n" + assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n" + + class SchedulerCommonTest(unittest.TestCase): scheduler_classes = () forward_default_kwargs = () @@ -70,15 +254,25 @@ def check_over_configs(self, time_step=0, **config): num_inference_steps = kwargs.pop("num_inference_steps", None) for scheduler_class in self.scheduler_classes: - sample = self.dummy_sample - residual = 0.1 * sample + # TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default + if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): + time_step = float(time_step) scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) + if scheduler_class == VQDiffusionScheduler: + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) + residual = model(sample, time_step) + else: + sample = self.dummy_sample + residual = 0.1 * sample + with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) @@ -86,7 +280,13 @@ def check_over_configs(self, time_step=0, **config): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps + # Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" @@ -98,15 +298,24 @@ def check_over_forward(self, time_step=0, **forward_kwargs): num_inference_steps = kwargs.pop("num_inference_steps", None) for scheduler_class in self.scheduler_classes: - sample = self.dummy_sample - residual = 0.1 * sample + if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): + time_step = float(time_step) scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) + if scheduler_class == VQDiffusionScheduler: + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) + residual = model(sample, time_step) + else: + sample = self.dummy_sample + residual = 0.1 * sample + with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) @@ -114,9 +323,12 @@ def check_over_forward(self, time_step=0, **forward_kwargs): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - torch.manual_seed(0) + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample - torch.manual_seed(0) + + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" @@ -127,15 +339,25 @@ def test_from_pretrained_save_pretrained(self): num_inference_steps = kwargs.pop("num_inference_steps", None) for scheduler_class in self.scheduler_classes: - sample = self.dummy_sample - residual = 0.1 * sample + timestep = 1 + if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): + timestep = float(timestep) scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) + if scheduler_class == VQDiffusionScheduler: + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) + residual = model(sample, timestep) + else: + sample = self.dummy_sample + residual = 0.1 * sample + with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) @@ -143,32 +365,84 @@ def test_from_pretrained_save_pretrained(self): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - torch.manual_seed(0) - output = scheduler.step(residual, 1, sample, **kwargs).prev_sample - torch.manual_seed(0) - new_output = new_scheduler.step(residual, 1, sample, **kwargs).prev_sample + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) + output = scheduler.step(residual, timestep, sample, **kwargs).prev_sample + + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) + new_output = new_scheduler.step(residual, timestep, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + def test_compatibles(self): + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + + scheduler = scheduler_class(**scheduler_config) + + assert all(c is not None for c in scheduler.compatibles) + + for comp_scheduler_cls in scheduler.compatibles: + comp_scheduler = comp_scheduler_cls.from_config(scheduler.config) + assert comp_scheduler is not None + + new_scheduler = scheduler_class.from_config(comp_scheduler.config) + + new_scheduler_config = {k: v for k, v in new_scheduler.config.items() if k in scheduler.config} + scheduler_diff = {k: v for k, v in new_scheduler.config.items() if k not in scheduler.config} + + # make sure that configs are essentially identical + assert new_scheduler_config == dict(scheduler.config) + + # make sure that only differences are for configs that are not in init + init_keys = inspect.signature(scheduler_class.__init__).parameters.keys() + assert set(scheduler_diff.keys()).intersection(set(init_keys)) == set() + + def test_from_pretrained(self): + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + + scheduler = scheduler_class(**scheduler_config) + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_pretrained(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + + assert scheduler.config == new_scheduler.config + def test_step_shape(self): kwargs = dict(self.forward_default_kwargs) num_inference_steps = kwargs.pop("num_inference_steps", None) + timestep_0 = 0 + timestep_1 = 1 + for scheduler_class in self.scheduler_classes: + if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): + timestep_0 = float(timestep_0) + timestep_1 = float(timestep_1) + scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - sample = self.dummy_sample - residual = 0.1 * sample + if scheduler_class == VQDiffusionScheduler: + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) + residual = model(sample, timestep_0) + else: + sample = self.dummy_sample + residual = 0.1 * sample if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - output_0 = scheduler.step(residual, 0, sample, **kwargs).prev_sample - output_1 = scheduler.step(residual, 1, sample, **kwargs).prev_sample + output_0 = scheduler.step(residual, timestep_0, sample, **kwargs).prev_sample + output_1 = scheduler.step(residual, timestep_1, sample, **kwargs).prev_sample self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, output_1.shape) @@ -203,26 +477,45 @@ def recursive_check(tuple_object, dict_object): kwargs = dict(self.forward_default_kwargs) num_inference_steps = kwargs.pop("num_inference_steps", 50) + timestep = 0 + if len(self.scheduler_classes) > 0 and self.scheduler_classes[0] == IPNDMScheduler: + timestep = 1 + for scheduler_class in self.scheduler_classes: + if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): + timestep = float(timestep) + scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - sample = self.dummy_sample - residual = 0.1 * sample + if scheduler_class == VQDiffusionScheduler: + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) + residual = model(sample, timestep) + else: + sample = self.dummy_sample + residual = 0.1 * sample if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - outputs_dict = scheduler.step(residual, 0, sample, **kwargs) + # Set the seed before state as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) + outputs_dict = scheduler.step(residual, timestep, sample, **kwargs) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - outputs_tuple = scheduler.step(residual, 0, sample, return_dict=False, **kwargs) + # Set the seed before state as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) + outputs_tuple = scheduler.step(residual, timestep, sample, return_dict=False, **kwargs) recursive_check(outputs_tuple, outputs_dict) @@ -230,23 +523,66 @@ def test_scheduler_public_api(self): for scheduler_class in self.scheduler_classes: scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - self.assertTrue( - hasattr(scheduler, "init_noise_sigma"), - f"{scheduler_class} does not implement a required attribute `init_noise_sigma`", - ) - self.assertTrue( - hasattr(scheduler, "scale_model_input"), - f"{scheduler_class} does not implement a required class method `scale_model_input(sample, timestep)`", - ) + + if scheduler_class != VQDiffusionScheduler: + self.assertTrue( + hasattr(scheduler, "init_noise_sigma"), + f"{scheduler_class} does not implement a required attribute `init_noise_sigma`", + ) + self.assertTrue( + hasattr(scheduler, "scale_model_input"), + f"{scheduler_class} does not implement a required class method `scale_model_input(sample," + " timestep)`", + ) self.assertTrue( hasattr(scheduler, "step"), f"{scheduler_class} does not implement a required class method `step(...)`", ) - sample = self.dummy_sample + if scheduler_class != VQDiffusionScheduler: + sample = self.dummy_sample + scaled_sample = scheduler.scale_model_input(sample, 0.0) + self.assertEqual(sample.shape, scaled_sample.shape) + + def test_add_noise_device(self): + for scheduler_class in self.scheduler_classes: + if scheduler_class == IPNDMScheduler: + # Skip until #990 is addressed + continue + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(100) + + sample = self.dummy_sample.to(torch_device) scaled_sample = scheduler.scale_model_input(sample, 0.0) self.assertEqual(sample.shape, scaled_sample.shape) + noise = torch.randn_like(scaled_sample).to(torch_device) + t = scheduler.timesteps[5][None] + noised = scheduler.add_noise(scaled_sample, noise, t) + self.assertEqual(noised.shape, scaled_sample.shape) + + def test_deprecated_kwargs(self): + for scheduler_class in self.scheduler_classes: + has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters + has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if" + " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`" + f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the" + " deprecated argument from `_deprecated_kwargs = []`" + ) + class DDPMSchedulerTest(SchedulerCommonTest): scheduler_classes = (DDPMScheduler,) @@ -284,6 +620,39 @@ def test_clip_sample(self): for clip_sample in [True, False]: self.check_over_configs(clip_sample=clip_sample) + def test_prediction_type(self): + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_deprecated_predict_epsilon(self): + deprecate("remove this test", "0.10.0", "remove") + for predict_epsilon in [True, False]: + self.check_over_configs(predict_epsilon=predict_epsilon) + + def test_deprecated_epsilon(self): + deprecate("remove this test", "0.10.0", "remove") + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + + sample = self.dummy_sample_deter + residual = 0.1 * self.dummy_sample_deter + time_step = 4 + + scheduler = scheduler_class(**scheduler_config) + scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config) + + kwargs = {} + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) + output = scheduler.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample + + kwargs = {} + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) + output_eps = scheduler_eps.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample + + assert (output - output_eps).abs().sum() < 1e-5 + def test_time_indices(self): for t in [0, 500, 999]: self.check_over_forward(time_step=t) @@ -441,6 +810,187 @@ def test_full_loop_with_no_set_alpha_to_one(self): assert abs(result_mean.item() - 0.1941) < 1e-3 +class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): + scheduler_classes = (DPMSolverMultistepScheduler,) + forward_default_kwargs = (("num_inference_steps", 25),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "solver_order": 2, + "prediction_type": "epsilon", + "thresholding": False, + "sample_max_value": 1.0, + "algorithm_type": "dpmsolver++", + "solver_type": "midpoint", + "lower_order_final": False, + } + + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + # copy over dummy past residuals + scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + new_scheduler.set_timesteps(num_inference_steps) + # copy over dummy past residuals + new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order] + + output, new_output = sample, sample + for t in range(time_step, time_step + scheduler.config.solver_order + 1): + output = scheduler.step(residual, t, output, **kwargs).prev_sample + new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_pretrained_save_pretrained(self): + pass + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residuals (must be after setting timesteps) + scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + # copy over dummy past residuals + new_scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residual (must be after setting timesteps) + new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order] + + output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + return sample + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # copy over dummy past residuals (must be done after set_timesteps) + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] + scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order] + + time_step_0 = scheduler.timesteps[5] + time_step_1 = scheduler.timesteps[6] + + output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample + output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [25, 50, 100, 999, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_thresholding(self): + self.check_over_configs(thresholding=False) + for order in [1, 2, 3]: + for solver_type in ["midpoint", "heun"]: + for threshold in [0.5, 1.0, 2.0]: + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs( + thresholding=True, + prediction_type=prediction_type, + sample_max_value=threshold, + algorithm_type="dpmsolver++", + solver_order=order, + solver_type=solver_type, + ) + + def test_solver_order_and_type(self): + for algorithm_type in ["dpmsolver", "dpmsolver++"]: + for solver_type in ["midpoint", "heun"]: + for order in [1, 2, 3]: + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + algorithm_type=algorithm_type, + ) + sample = self.full_loop( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + algorithm_type=algorithm_type, + ) + assert not torch.isnan(sample).any(), "Samples have nan numbers" + + def test_lower_order_final(self): + self.check_over_configs(lower_order_final=True) + self.check_over_configs(lower_order_final=False) + + def test_inference_steps(self): + for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: + self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0) + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.3301) < 1e-3 + + class PNDMSchedulerTest(SchedulerCommonTest): scheduler_classes = (PNDMScheduler,) forward_default_kwargs = (("num_inference_steps", 50),) @@ -472,7 +1022,7 @@ def check_over_configs(self, time_step=0, **config): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) new_scheduler.set_timesteps(num_inference_steps) # copy over dummy past residuals new_scheduler.ets = dummy_past_residuals[:] @@ -507,7 +1057,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) # copy over dummy past residuals new_scheduler.set_timesteps(num_inference_steps) @@ -725,7 +1275,7 @@ def check_over_configs(self, time_step=0, **config): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) output = scheduler.step_pred( residual, time_step, sample, generator=torch.manual_seed(0), **kwargs @@ -756,7 +1306,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) output = scheduler.step_pred( residual, time_step, sample, generator=torch.manual_seed(0), **kwargs @@ -901,3 +1451,428 @@ def test_full_loop_no_noise(self): assert abs(result_sum.item() - 1006.388) < 1e-2 assert abs(result_mean.item() - 1.31) < 1e-3 + + def test_full_loop_device(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 1006.388) < 1e-2 + assert abs(result_mean.item() - 1.31) < 1e-3 + + +class EulerDiscreteSchedulerTest(SchedulerCommonTest): + scheduler_classes = (EulerDiscreteScheduler,) + num_inference_steps = 10 + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1100, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "trained_betas": None, + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "scaled_linear"]: + self.check_over_configs(beta_schedule=schedule) + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 10.0807) < 1e-2 + assert abs(result_mean.item() - 0.0131) < 1e-3 + + def test_full_loop_device(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for t in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 10.0807) < 1e-2 + assert abs(result_mean.item() - 0.0131) < 1e-3 + + +class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): + scheduler_classes = (EulerAncestralDiscreteScheduler,) + num_inference_steps = 10 + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1100, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "trained_betas": None, + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "scaled_linear"]: + self.check_over_configs(beta_schedule=schedule) + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if torch_device in ["cpu", "mps"]: + assert abs(result_sum.item() - 152.3192) < 1e-2 + assert abs(result_mean.item() - 0.1983) < 1e-3 + else: + # CUDA + assert abs(result_sum.item() - 144.8084) < 1e-2 + assert abs(result_mean.item() - 0.18855) < 1e-3 + + def test_full_loop_device(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for t in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if str(torch_device).startswith("cpu"): + # The following sum varies between 148 and 156 on mps. Why? + assert abs(result_sum.item() - 152.3192) < 1e-2 + assert abs(result_mean.item() - 0.1983) < 1e-3 + elif str(torch_device).startswith("mps"): + # Larger tolerance on mps + assert abs(result_mean.item() - 0.1983) < 1e-2 + else: + # CUDA + assert abs(result_sum.item() - 144.8084) < 1e-2 + assert abs(result_mean.item() - 0.18855) < 1e-3 + + +class IPNDMSchedulerTest(SchedulerCommonTest): + scheduler_classes = (IPNDMScheduler,) + forward_default_kwargs = (("num_inference_steps", 50),) + + def get_scheduler_config(self, **kwargs): + config = {"num_train_timesteps": 1000} + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + # copy over dummy past residuals + scheduler.ets = dummy_past_residuals[:] + + if time_step is None: + time_step = scheduler.timesteps[len(scheduler.timesteps) // 2] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + new_scheduler.set_timesteps(num_inference_steps) + # copy over dummy past residuals + new_scheduler.ets = dummy_past_residuals[:] + + output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_pretrained_save_pretrained(self): + pass + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residuals (must be after setting timesteps) + scheduler.ets = dummy_past_residuals[:] + + if time_step is None: + time_step = scheduler.timesteps[len(scheduler.timesteps) // 2] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + # copy over dummy past residuals + new_scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residual (must be after setting timesteps) + new_scheduler.ets = dummy_past_residuals[:] + + output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + return sample + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # copy over dummy past residuals (must be done after set_timesteps) + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] + scheduler.ets = dummy_past_residuals[:] + + time_step_0 = scheduler.timesteps[5] + time_step_1 = scheduler.timesteps[6] + + output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample + output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample + output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps, time_step=None) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): + self.check_over_forward(num_inference_steps=num_inference_steps, time_step=None) + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 2540529) < 10 + + +class VQDiffusionSchedulerTest(SchedulerCommonTest): + scheduler_classes = (VQDiffusionScheduler,) + + def get_scheduler_config(self, **kwargs): + config = { + "num_vec_classes": 4097, + "num_train_timesteps": 100, + } + + config.update(**kwargs) + return config + + def dummy_sample(self, num_vec_classes): + batch_size = 4 + height = 8 + width = 8 + + sample = torch.randint(0, num_vec_classes, (batch_size, height * width)) + + return sample + + @property + def dummy_sample_deter(self): + assert False + + def dummy_model(self, num_vec_classes): + def model(sample, t, *args): + batch_size, num_latent_pixels = sample.shape + logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels)) + return_value = F.log_softmax(logits.double(), dim=1).float() + return return_value + + return model + + def test_timesteps(self): + for timesteps in [2, 5, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_num_vec_classes(self): + for num_vec_classes in [5, 100, 1000, 4000]: + self.check_over_configs(num_vec_classes=num_vec_classes) + + def test_time_indices(self): + for t in [0, 50, 99]: + self.check_over_forward(time_step=t) + + def test_add_noise_device(self): + pass diff --git a/tests/test_scheduler_flax.py b/tests/test_scheduler_flax.py new file mode 100644 index 000000000000..5ada689b724d --- /dev/null +++ b/tests/test_scheduler_flax.py @@ -0,0 +1,934 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import tempfile +import unittest +from typing import Dict, List, Tuple + +from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler +from diffusers.utils import deprecate, is_flax_available +from diffusers.utils.testing_utils import require_flax + + +if is_flax_available(): + import jax + import jax.numpy as jnp + from jax import random + + jax_device = jax.default_backend() + + +@require_flax +class FlaxSchedulerCommonTest(unittest.TestCase): + scheduler_classes = () + forward_default_kwargs = () + + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + key1, key2 = random.split(random.PRNGKey(0)) + sample = random.uniform(key1, (batch_size, num_channels, height, width)) + + return sample, key2 + + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + num_elems = batch_size * num_channels * height * width + sample = jnp.arange(num_elems) + sample = sample.reshape(num_channels, height, width, batch_size) + sample = sample / num_elems + return jnp.transpose(sample, (3, 0, 1, 2)) + + def get_scheduler_config(self): + raise NotImplementedError + + def dummy_model(self): + def model(sample, t, *args): + return sample * t / (t + 1) + + return model + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + kwargs.update(forward_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_pretrained_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, key = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample + output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." + ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, key = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs) + + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def test_deprecated_kwargs(self): + for scheduler_class in self.scheduler_classes: + has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters + has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if" + " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`" + f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the" + " deprecated argument from `_deprecated_kwargs = []`" + ) + + +@require_flax +class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest): + scheduler_classes = (FlaxDDPMScheduler,) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "variance_type": "fixed_small", + "clip_sample": True, + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [1, 5, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_variance_type(self): + for variance in ["fixed_small", "fixed_large", "other"]: + self.check_over_configs(variance_type=variance) + + def test_clip_sample(self): + for clip_sample in [True, False]: + self.check_over_configs(clip_sample=clip_sample) + + def test_time_indices(self): + for t in [0, 500, 999]: + self.check_over_forward(time_step=t) + + def test_variance(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + assert jnp.sum(jnp.abs(scheduler._get_variance(0) - 0.0)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(999) - 0.02)) < 1e-5 + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_trained_timesteps = len(scheduler) + + model = self.dummy_model() + sample = self.dummy_sample_deter + key1, key2 = random.split(random.PRNGKey(0)) + + for t in reversed(range(num_trained_timesteps)): + # 1. predict noise residual + residual = model(sample, t) + + # 2. predict previous mean of sample x_t-1 + output = scheduler.step(state, residual, t, sample, key1) + pred_prev_sample = output.prev_sample + state = output.state + key1, key2 = random.split(key2) + + # if t > 0: + # noise = self.dummy_sample_deter + # variance = scheduler.get_variance(t) ** (0.5) * noise + # + # sample = pred_prev_sample + variance + sample = pred_prev_sample + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 255.0714) < 1e-2 + assert abs(result_mean - 0.332124) < 1e-3 + else: + assert abs(result_sum - 255.1113) < 1e-2 + assert abs(result_mean - 0.332176) < 1e-3 + + +@require_flax +class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): + scheduler_classes = (FlaxDDIMScheduler,) + forward_default_kwargs = (("num_inference_steps", 50),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + } + + config.update(**kwargs) + return config + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + key1, key2 = random.split(random.PRNGKey(0)) + + num_inference_steps = 10 + + model = self.dummy_model() + sample = self.dummy_sample_deter + + state = scheduler.set_timesteps(state, num_inference_steps) + + for t in state.timesteps: + residual = model(sample, t) + output = scheduler.step(state, residual, t, sample) + sample = output.prev_sample + state = output.state + key1, key2 = random.split(key2) + + return sample + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_pretrained_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + kwargs.update(forward_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." + ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) + + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample + output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [100, 500, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 5) + assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all() + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_time_indices(self): + for t in [1, 10, 49]: + self.check_over_forward(time_step=t) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): + self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + + def test_variance(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + assert jnp.sum(jnp.abs(scheduler._get_variance(0, 0, state.alphas_cumprod) - 0.0)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(420, 400, state.alphas_cumprod) - 0.14771)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(980, 960, state.alphas_cumprod) - 0.32460)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(0, 0, state.alphas_cumprod) - 0.0)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(487, 486, state.alphas_cumprod) - 0.00979)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(999, 998, state.alphas_cumprod) - 0.02)) < 1e-5 + + def test_full_loop_no_noise(self): + sample = self.full_loop() + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum - 172.0067) < 1e-2 + assert abs(result_mean - 0.223967) < 1e-3 + + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 149.8409) < 1e-2 + assert abs(result_mean - 0.1951) < 1e-3 + else: + assert abs(result_sum - 149.8295) < 1e-2 + assert abs(result_mean - 0.1951) < 1e-3 + + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + pass + # FIXME: both result_sum and result_mean are nan on TPU + # assert jnp.isnan(result_sum) + # assert jnp.isnan(result_mean) + else: + assert abs(result_sum - 149.0784) < 1e-2 + assert abs(result_mean - 0.1941) < 1e-3 + + def test_prediction_type(self): + for prediction_type in ["epsilon", "sample", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_deprecated_predict_epsilon(self): + deprecate("remove this test", "0.10.0", "remove") + for predict_epsilon in [True, False]: + self.check_over_configs(predict_epsilon=predict_epsilon) + + def test_deprecated_predict_epsilon_to_prediction_type(self): + deprecate("remove this test", "0.10.0", "remove") + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(predict_epsilon=True) + scheduler = scheduler_class.from_config(scheduler_config) + assert scheduler.prediction_type == "epsilon" + + scheduler_config = self.get_scheduler_config(predict_epsilon=False) + scheduler = scheduler_class.from_config(scheduler_config) + assert scheduler.prediction_type == "sample" + + +@require_flax +class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): + scheduler_classes = (FlaxPNDMScheduler,) + forward_default_kwargs = (("num_inference_steps", 50),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + } + + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample, _ = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + # copy over dummy past residuals + state = state.replace(ets=dummy_past_residuals[:]) + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) + # copy over dummy past residuals + new_state = new_state.replace(ets=dummy_past_residuals[:]) + + (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs) + (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical" + + output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) + new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_pretrained_save_pretrained(self): + pass + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." + ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) + + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample, _ = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + + # copy over dummy past residuals (must be after setting timesteps) + scheduler.ets = dummy_past_residuals[:] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + # copy over dummy past residuals + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) + + # copy over dummy past residual (must be after setting timesteps) + new_state.replace(ets=dummy_past_residuals[:]) + + output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs) + new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) + new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + + for i, t in enumerate(state.prk_timesteps): + residual = model(sample, t) + sample, state = scheduler.step_prk(state, residual, t, sample) + + for i, t in enumerate(state.plms_timesteps): + residual = model(sample, t) + sample, state = scheduler.step_plms(state, residual, t, sample) + + return sample + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # copy over dummy past residuals (must be done after set_timesteps) + dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) + state = state.replace(ets=dummy_past_residuals[:]) + + output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs) + output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs) + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs) + output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs) + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 10, shape=()) + assert jnp.equal( + state.timesteps, + jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]), + ).all() + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_time_indices(self): + for t in [1, 5, 10]: + self.check_over_forward(time_step=t) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): + self.check_over_forward(num_inference_steps=num_inference_steps) + + def test_pow_of_3_inference_steps(self): + # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3 + num_inference_steps = 27 + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + + # before power of 3 fix, would error on first step, so we only need to do two + for i, t in enumerate(state.prk_timesteps[:2]): + sample, state = scheduler.step_prk(state, residual, t, sample) + + def test_inference_plms_no_past_residuals(self): + with self.assertRaises(ValueError): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 198.1275) < 1e-2 + assert abs(result_mean - 0.2580) < 1e-3 + else: + assert abs(result_sum - 198.1318) < 1e-2 + assert abs(result_mean - 0.2580) < 1e-3 + + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 186.83226) < 1e-2 + assert abs(result_mean - 0.24327) < 1e-3 + else: + assert abs(result_sum - 186.9466) < 1e-2 + assert abs(result_mean - 0.24342) < 1e-3 + + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 186.83226) < 1e-2 + assert abs(result_mean - 0.24327) < 1e-3 + else: + assert abs(result_sum - 186.9482) < 1e-2 + assert abs(result_mean - 0.2434) < 1e-3 diff --git a/tests/test_utils.py b/tests/test_utils.py index 35cf57421014..761242eb9ac4 100755 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -26,7 +26,7 @@ class DeprecateTester(unittest.TestCase): def test_deprecate_function_arg(self): kwargs = {"deprecated_arg": 4} - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: output = deprecate("deprecated_arg", self.higher_version, "message", take_from=kwargs) assert output == 4 @@ -39,7 +39,7 @@ def test_deprecate_function_arg(self): def test_deprecate_function_arg_tuple(self): kwargs = {"deprecated_arg": 4} - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: output = deprecate(("deprecated_arg", self.higher_version, "message"), take_from=kwargs) assert output == 4 @@ -51,7 +51,7 @@ def test_deprecate_function_arg_tuple(self): def test_deprecate_function_args(self): kwargs = {"deprecated_arg_1": 4, "deprecated_arg_2": 8} - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: output_1, output_2 = deprecate( ("deprecated_arg_1", self.higher_version, "Hey"), ("deprecated_arg_2", self.higher_version, "Hey"), @@ -81,7 +81,7 @@ def test_deprecate_function_incorrect_arg(self): assert "got an unexpected keyword argument `deprecated_arg`" in str(error.exception) def test_deprecate_arg_no_kwarg(self): - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: deprecate(("deprecated_arg", self.higher_version, "message")) assert ( @@ -90,7 +90,7 @@ def test_deprecate_arg_no_kwarg(self): ) def test_deprecate_args_no_kwarg(self): - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: deprecate( ("deprecated_arg_1", self.higher_version, "Hey"), ("deprecated_arg_2", self.higher_version, "Hey"), @@ -108,7 +108,7 @@ def test_deprecate_class_obj(self): class Args: arg = 5 - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: arg = deprecate(("arg", self.higher_version, "message"), take_from=Args()) assert arg == 5 @@ -122,7 +122,7 @@ class Args: arg = 5 foo = 7 - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: arg_1, arg_2 = deprecate( ("arg", self.higher_version, "message"), ("foo", self.higher_version, "message"), @@ -158,7 +158,7 @@ def test_deprecate_incorrect_version(self): ) def test_deprecate_incorrect_no_standard_warn(self): - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False) assert str(warning.warning) == "This message is better!!!" diff --git a/utils/check_copies.py b/utils/check_copies.py index 50f02cac658d..16782397da74 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. +# Copyright 2022 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ import argparse import glob +import importlib.util import os import re @@ -24,52 +25,17 @@ # All paths are set with the intent you should run this script from the root of the repo with the command # python utils/check_copies.py -TRANSFORMERS_PATH = "src/diffusers" -PATH_TO_DOCS = "docs/source/en" +DIFFUSERS_PATH = "src/diffusers" REPO_PATH = "." -# Mapping for files that are full copies of others (keys are copies, values the file to keep them up to data with) -FULL_COPIES = { - "examples/tensorflow/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py", - "examples/flax/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py", -} - - -LOCALIZED_READMES = { - # If the introduction or the conclusion of the list change, the prompts may need to be updated. - "README.md": { - "start_prompt": "🤗 Transformers currently provides the following architectures", - "end_prompt": "1. Want to contribute a new model?", - "format_model_list": ( - "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by" - " {paper_authors}.{supplements}" - ), - }, - "README_zh-hans.md": { - "start_prompt": "🤗 Transformers 目前支持如下的架构", - "end_prompt": "1. 想要贡献新的模型?", - "format_model_list": ( - "**[{title}]({model_link})** (来自 {paper_affiliations}) 伴随论文 {paper_title_link} 由 {paper_authors}" - " 发布。{supplements}" - ), - }, - "README_zh-hant.md": { - "start_prompt": "🤗 Transformers 目前支援以下的架構", - "end_prompt": "1. 想要貢獻新的模型?", - "format_model_list": ( - "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by" - " {paper_authors}.{supplements}" - ), - }, - "README_ko.md": { - "start_prompt": "🤗 Transformers는 다음 모델들을 제공합니다", - "end_prompt": "1. 새로운 모델을 올리고 싶나요?", - "format_model_list": ( - "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by" - " {paper_authors}.{supplements}" - ), - }, -} + +# This is to make sure the diffusers module imported is the one in the repo. +spec = importlib.util.spec_from_file_location( + "diffusers", + os.path.join(DIFFUSERS_PATH, "__init__.py"), + submodule_search_locations=[DIFFUSERS_PATH], +) +diffusers_module = spec.loader.load_module() def _should_continue(line, indent): @@ -83,14 +49,14 @@ def find_code_in_diffusers(object_name): # First let's find the module where our object lives. module = parts[i] - while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")): + while i < len(parts) and not os.path.isfile(os.path.join(DIFFUSERS_PATH, f"{module}.py")): i += 1 if i < len(parts): module = os.path.join(module, parts[i]) if i >= len(parts): raise ValueError(f"`object_name` should begin with the name of a module of diffusers but got {object_name}.") - with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f: + with open(os.path.join(DIFFUSERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() # Now let's find the class / func in the code! @@ -121,6 +87,7 @@ def find_code_in_diffusers(object_name): _re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+diffusers\.(\S+\.\S+)\s*($|\S.*$)") _re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)") +_re_fill_pattern = re.compile(r"]*>") def get_indent(code): @@ -140,7 +107,7 @@ def blackify(code): has_indent = len(get_indent(code)) > 0 if has_indent: code = f"class Bla:\n{code}" - mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119, preview=True) + mode = black.Mode(target_versions={black.TargetVersion.PY37}, line_length=119, preview=True) result = black.format_str(code, mode=mode) result, _ = style_docstrings_in_code(result) return result[len("class Bla:\n") :] if has_indent else result @@ -149,7 +116,6 @@ def blackify(code): def is_copy_consistent(filename, overwrite=False): """ Check if the code commented as a copy in `filename` matches the original. - Return the differences or overwrites the content depending on `overwrite`. """ with open(filename, "r", encoding="utf-8", newline="\n") as f: @@ -187,6 +153,10 @@ def is_copy_consistent(filename, overwrite=False): observed_code_lines = lines[start_index:line_index] observed_code = "".join(observed_code_lines) + # Remove any nested `Copied from` comments to avoid circular copies + theoretical_code = [line for line in theoretical_code.split("\n") if _re_copy_warning.search(line) is None] + theoretical_code = "\n".join(theoretical_code) + # Before comparing, use the `replace_pattern` on the original code. if len(replace_pattern) > 0: patterns = replace_pattern.replace("with", "").split(",") @@ -221,7 +191,7 @@ def is_copy_consistent(filename, overwrite=False): def check_copies(overwrite: bool = False): - all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True) + all_files = glob.glob(os.path.join(DIFFUSERS_PATH, "**/*.py"), recursive=True) diffs = [] for filename in all_files: new_diffs = is_copy_consistent(filename, overwrite) @@ -235,224 +205,9 @@ def check_copies(overwrite: bool = False): ) -# check_model_list_copy(overwrite=overwrite) - - -def check_full_copies(overwrite: bool = False): - diffs = [] - for target, source in FULL_COPIES.items(): - with open(source, "r", encoding="utf-8") as f: - source_code = f.read() - with open(target, "r", encoding="utf-8") as f: - target_code = f.read() - if source_code != target_code: - if overwrite: - with open(target, "w", encoding="utf-8") as f: - print(f"Replacing the content of {target} by the one of {source}.") - f.write(source_code) - else: - diffs.append(f"- {target}: copy does not match {source}.") - - if not overwrite and len(diffs) > 0: - diff = "\n".join(diffs) - raise Exception( - "Found the following copy inconsistencies:\n" - + diff - + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them." - ) - - -def get_model_list(filename, start_prompt, end_prompt): - """Extracts the model list from the README.""" - with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f: - lines = f.readlines() - # Find the start of the list. - start_index = 0 - while not lines[start_index].startswith(start_prompt): - start_index += 1 - start_index += 1 - - result = [] - current_line = "" - end_index = start_index - - while not lines[end_index].startswith(end_prompt): - if lines[end_index].startswith("1."): - if len(current_line) > 1: - result.append(current_line) - current_line = lines[end_index] - elif len(lines[end_index]) > 1: - current_line = f"{current_line[:-1]} {lines[end_index].lstrip()}" - end_index += 1 - if len(current_line) > 1: - result.append(current_line) - - return "".join(result) - - -def convert_to_localized_md(model_list, localized_model_list, format_str): - """Convert `model_list` to each localized README.""" - - def _rep(match): - title, model_link, paper_affiliations, paper_title_link, paper_authors, supplements = match.groups() - return format_str.format( - title=title, - model_link=model_link, - paper_affiliations=paper_affiliations, - paper_title_link=paper_title_link, - paper_authors=paper_authors, - supplements=" " + supplements.strip() if len(supplements) != 0 else "", - ) - - # This regex captures metadata from an English model description, including model title, model link, - # affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for example). - _re_capture_meta = re.compile( - r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\* \(from ([^)]*)\)[^\[]*([^\)]*\)).*?by (.*?[A-Za-z\*]{2,}?)\. (.*)$" - ) - # This regex is used to synchronize link. - _re_capture_title_link = re.compile(r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\*") - - if len(localized_model_list) == 0: - localized_model_index = {} - else: - try: - localized_model_index = { - re.search(r"\*\*\[([^\]]*)", line).groups()[0]: line - for line in localized_model_list.strip().split("\n") - } - except AttributeError: - raise AttributeError("A model name in localized READMEs cannot be recognized.") - - model_keys = [re.search(r"\*\*\[([^\]]*)", line).groups()[0] for line in model_list.strip().split("\n")] - - # We exclude keys in localized README not in the main one. - readmes_match = not any([k not in model_keys for k in localized_model_index]) - localized_model_index = {k: v for k, v in localized_model_index.items() if k in model_keys} - - for model in model_list.strip().split("\n"): - title, model_link = _re_capture_title_link.search(model).groups() - if title not in localized_model_index: - readmes_match = False - # Add an anchor white space behind a model description string for regex. - # If metadata cannot be captured, the English version will be directly copied. - localized_model_index[title] = _re_capture_meta.sub(_rep, model + " ") - else: - # Synchronize link - localized_model_index[title] = _re_capture_title_link.sub( - f"**[{title}]({model_link})**", localized_model_index[title], count=1 - ) - - sorted_index = sorted(localized_model_index.items(), key=lambda x: x[0].lower()) - - return readmes_match, "\n".join(map(lambda x: x[1], sorted_index)) + "\n" - - -def convert_readme_to_index(model_list): - model_list = model_list.replace("https://huggingface.co/docs/diffusers/main/", "") - return model_list.replace("https://huggingface.co/docs/diffusers/", "") - - -def _find_text_in_file(filename, start_prompt, end_prompt): - """ - Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty - lines. - """ - with open(filename, "r", encoding="utf-8", newline="\n") as f: - lines = f.readlines() - # Find the start prompt. - start_index = 0 - while not lines[start_index].startswith(start_prompt): - start_index += 1 - start_index += 1 - - end_index = start_index - while not lines[end_index].startswith(end_prompt): - end_index += 1 - end_index -= 1 - - while len(lines[start_index]) <= 1: - start_index += 1 - while len(lines[end_index]) <= 1: - end_index -= 1 - end_index += 1 - return "".join(lines[start_index:end_index]), start_index, end_index, lines - - -def check_model_list_copy(overwrite=False, max_per_line=119): - """Check the model lists in the README and index.rst are consistent and maybe `overwrite`.""" - # Fix potential doc links in the README - with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f: - readme = f.read() - new_readme = readme.replace("https://huggingface.co/diffusers", "https://huggingface.co/docs/diffusers") - new_readme = new_readme.replace( - "https://huggingface.co/docs/main/diffusers", "https://huggingface.co/docs/diffusers/main" - ) - if new_readme != readme: - if overwrite: - with open(os.path.join(REPO_PATH, "README.md"), "w", encoding="utf-8", newline="\n") as f: - f.write(new_readme) - else: - raise ValueError( - "The main README contains wrong links to the documentation of Transformers. Run `make fix-copies` to " - "automatically fix them." - ) - - # If the introduction or the conclusion of the list change, the prompts may need to be updated. - index_list, start_index, end_index, lines = _find_text_in_file( - filename=os.path.join(PATH_TO_DOCS, "index.mdx"), - start_prompt="