Skip to content

Commit

Permalink
Move from Flux to SDXL
Browse files Browse the repository at this point in the history
I tried. I really tried. I'm not skilled enough in fighting the
diffusers library to make Flux use less than 49 GB of VRAM. I tried
everything that was documented to work and found out it was all
broken in any number of ways that made me feel like I was being gaslit
hard.

It's not as flashy to use SDXL here, but I needed to get _something_
working, so I fell back to what I know will work.
  • Loading branch information
Xe committed Oct 24, 2024
1 parent 5ce63e9 commit 02c5148
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 63 deletions.
3 changes: 2 additions & 1 deletion Brewfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ brew "python@3.11"
brew "pipenv"
brew "cog"
brew "huggingface-cli"
brew "awscli"
brew "awscli"
brew "jq"
75 changes: 62 additions & 13 deletions Brewfile.lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -168,40 +168,89 @@
}
},
"awscli": {
"version": "2.18.10",
"version": "2.18.11",
"bottle": {
"rebuild": 0,
"root_url": "https://ghcr.io/v2/homebrew/core",
"files": {
"arm64_sequoia": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/awscli/blobs/sha256:6d4e4c816926b70f0eaa513ad13b64a22ec89fbe2c59f7b3930675cf234c5d9d",
"sha256": "6d4e4c816926b70f0eaa513ad13b64a22ec89fbe2c59f7b3930675cf234c5d9d"
"url": "https://ghcr.io/v2/homebrew/core/awscli/blobs/sha256:4e8870ebb1d7c025d942f7d22963c4a964af0191405f82384d35bbd2be87ce37",
"sha256": "4e8870ebb1d7c025d942f7d22963c4a964af0191405f82384d35bbd2be87ce37"
},
"arm64_sonoma": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/awscli/blobs/sha256:a81a0befc34528cb3bf889c29f4b48d40f82f419d75e54da72648d73323775e6",
"sha256": "a81a0befc34528cb3bf889c29f4b48d40f82f419d75e54da72648d73323775e6"
"url": "https://ghcr.io/v2/homebrew/core/awscli/blobs/sha256:1241652148698d803bc9829a58a98eddff1fa814afff804550395081dec28823",
"sha256": "1241652148698d803bc9829a58a98eddff1fa814afff804550395081dec28823"
},
"arm64_ventura": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/awscli/blobs/sha256:d706e9a530a9d91afe8d6c530f2b15282af1b413079b51260fa8e0b626d02d32",
"sha256": "d706e9a530a9d91afe8d6c530f2b15282af1b413079b51260fa8e0b626d02d32"
"url": "https://ghcr.io/v2/homebrew/core/awscli/blobs/sha256:8004a01fe03d5f124ba5c574d342264ed54fb4d327370f28af195bdc9ad35bf1",
"sha256": "8004a01fe03d5f124ba5c574d342264ed54fb4d327370f28af195bdc9ad35bf1"
},
"sonoma": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/awscli/blobs/sha256:b80c11a4d77344f0069ddafe823fdbb81c91f97f0537506be6d5f1d57fab6959",
"sha256": "b80c11a4d77344f0069ddafe823fdbb81c91f97f0537506be6d5f1d57fab6959"
"url": "https://ghcr.io/v2/homebrew/core/awscli/blobs/sha256:9b246a9834230e6bda1930bda806434cc4d6206acf9a87dc53531756af41e5de",
"sha256": "9b246a9834230e6bda1930bda806434cc4d6206acf9a87dc53531756af41e5de"
},
"ventura": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/awscli/blobs/sha256:a6f5bec1b905c28da46e15b042a092c67bd18e0941c772684b2151ed2bfd1d5b",
"sha256": "a6f5bec1b905c28da46e15b042a092c67bd18e0941c772684b2151ed2bfd1d5b"
"url": "https://ghcr.io/v2/homebrew/core/awscli/blobs/sha256:2f11cb73572854c6af22fe14e30cf4cc2def8981d276b0c58584c80cf7b5610d",
"sha256": "2f11cb73572854c6af22fe14e30cf4cc2def8981d276b0c58584c80cf7b5610d"
},
"x86_64_linux": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/awscli/blobs/sha256:9bdbd00bceae9178f660b7fa55cad31073188c7d70edb87b1761ba83b8f33e11",
"sha256": "9bdbd00bceae9178f660b7fa55cad31073188c7d70edb87b1761ba83b8f33e11"
"url": "https://ghcr.io/v2/homebrew/core/awscli/blobs/sha256:bfb0081142aebec18034e252ade04ea354b7b334f56d47499d6a65eaa046c262",
"sha256": "bfb0081142aebec18034e252ade04ea354b7b334f56d47499d6a65eaa046c262"
}
}
}
},
"jq": {
"version": "1.7.1",
"bottle": {
"rebuild": 1,
"root_url": "https://ghcr.io/v2/homebrew/core",
"files": {
"arm64_sequoia": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/jq/blobs/sha256:a10c82b07e393869d4467ad3e8ba26346d026b1ad3533d31dbb5e72abe9a7968",
"sha256": "a10c82b07e393869d4467ad3e8ba26346d026b1ad3533d31dbb5e72abe9a7968"
},
"arm64_sonoma": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/jq/blobs/sha256:7d01bc414859db57e055c814daa10e9c586626381ea329862ad4300f9fee78ce",
"sha256": "7d01bc414859db57e055c814daa10e9c586626381ea329862ad4300f9fee78ce"
},
"arm64_ventura": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/jq/blobs/sha256:b1a185e72ca020f08a8de22fabe1ad2425bf48d2e0378c5e07a6678020fa3e15",
"sha256": "b1a185e72ca020f08a8de22fabe1ad2425bf48d2e0378c5e07a6678020fa3e15"
},
"arm64_monterey": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/jq/blobs/sha256:8f8c06332f413f5259b360ed65dc3ef21b5d3f2fff35160bc12367e53cbd06bf",
"sha256": "8f8c06332f413f5259b360ed65dc3ef21b5d3f2fff35160bc12367e53cbd06bf"
},
"sonoma": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/jq/blobs/sha256:6bc01de99fd7f091b86880534842132a876f2d3043e3932ea75efc5f51c40aea",
"sha256": "6bc01de99fd7f091b86880534842132a876f2d3043e3932ea75efc5f51c40aea"
},
"ventura": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/jq/blobs/sha256:03227348d3845fe16ed261ad020402c1f23c56e73f65799ce278af4bac63c799",
"sha256": "03227348d3845fe16ed261ad020402c1f23c56e73f65799ce278af4bac63c799"
},
"monterey": {
"cellar": ":any",
"url": "https://ghcr.io/v2/homebrew/core/jq/blobs/sha256:25aab2c539a41e4d67cd3d44353aac3cdd159ea815fec2b8dd82fbf038c559cc",
"sha256": "25aab2c539a41e4d67cd3d44353aac3cdd159ea815fec2b8dd82fbf038c559cc"
},
"x86_64_linux": {
"cellar": ":any_skip_relocation",
"url": "https://ghcr.io/v2/homebrew/core/jq/blobs/sha256:9559d8278cf20ad0294f2059855e1bc9d2bcabfd2bd5b5774c66006d1f201ad8",
"sha256": "9559d8278cf20ad0294f2059855e1bc9d2bcabfd2bd5b5774c66006d1f201ad8"
}
}
}
Expand Down
6 changes: 2 additions & 4 deletions cog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ build:
- "libxext6"
- "wget"
python_packages:
- "git+https://github.com/huggingface/diffusers.git@249a9e48e8f8aac4356d5a285c8ba0c600a80f64"
- "diffusers"
- "torch==2.2"
- "transformers==4.41.2"
- "accelerate==0.31.0"
Expand All @@ -21,8 +21,6 @@ build:
- "boto3"
- "sqids"
- "numpy<2"
- "cog"

run:
- pip install "git+https://github.com/kylemclaren/cog.git@d486018abab50cb3d9b4edd4916f256009a8bd95"

predict: "predict.py:Predictor"
53 changes: 12 additions & 41 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
import torch
from cog import BasePredictor, Input, Path
from diffusers import FluxPipeline, FluxImg2ImgPipeline
from diffusers import StableDiffusionXLPipeline

# from diffusers.pipelines.stable_diffusion.safety_checker import (
# StableDiffusionSafetyChecker,
Expand All @@ -21,7 +21,7 @@
from PIL import ImageOps, Image


FLUX_MODEL_CACHE = "/src/flux-cache"
FLUX_MODEL_CACHE = "/src/"
FEATURE_EXTRACTOR = "./feature-extractor"
PUBLIC_BUCKET_NAME = os.getenv("PUBLIC_BUCKET_NAME", "xe-flux")

Expand Down Expand Up @@ -110,33 +110,19 @@ def setup(self):
print("Downloading model")
model_path = copy_from_tigris(destdir=FLUX_MODEL_CACHE)

print("Loading flux txt2img pipeline...")
print("Loading sdxl txt2img pipeline...")
dtype = torch.bfloat16

self.txt2img_pipe = FluxPipeline.from_pretrained(
model_path, torch_dtype=torch.bfloat16
)

self.txt2img_pipe.to("cuda")

print("Loading flux img2img pipeline...")

self.img2img_pipe = FluxImg2ImgPipeline.from_pretrained(
model_path, torch_dtype=torch.bfloat16
)

self.img2img_pipe.image_processor = VaeImageProcessor(
vae_scale_factor=16,
vae_latent_channels=self.img2img_pipe.vae.config.latent_channels,
)

self.img2img_pipe.to("cuda")
self.txt2img_pipe = StableDiffusionXLPipeline.from_pretrained(
model_path, torch_dtype=dtype,
).to("cuda")

print("setup took: ", time.time() - start)

def aspect_ratio_to_width_height(self, aspect_ratio: str):
aspect_ratios = {
"1:1": (1024, 1024),
"16:9": (1344, 768),
"16:9": (1344+64, 768+64),
"21:9": (1536, 640),
"3:2": (1216, 832),
"2:3": (832, 1216),
Expand All @@ -156,10 +142,6 @@ def predict(
description="Input prompt",
default="",
),
image: Path = Input(
description="Input image for img2img mode",
default=None,
),
aspect_ratio: str = Input(
description="Aspect ratio for the generated image",
choices=["1:1", "16:9", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"],
Expand Down Expand Up @@ -210,21 +192,10 @@ def predict(

flux_kwargs = {}

if image:
print("img2img mode")
tmp_img = Image.open(image).convert("RGB")
width, height = resize_image_dimensions(tmp_img.size)
flux_kwargs["image"] = tmp_img.resize((width, height), Image.LANCZOS)
flux_kwargs["width"] = width
flux_kwargs["height"] = height
flux_kwargs["strength"] = prompt_strength
pipe = self.img2img_pipe
else:
print("txt2img mode")
width, height = resize_image_dimensions((width, height))
flux_kwargs["width"] = width
flux_kwargs["height"] = height
pipe = self.txt2img_pipe
width, height = resize_image_dimensions((width, height))
flux_kwargs["width"] = width
flux_kwargs["height"] = height
pipe = self.txt2img_pipe

generator = torch.Generator("cuda").manual_seed(seed)

Expand Down
2 changes: 1 addition & 1 deletion scripts/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def download_batch(batch) -> int:
def copy_from_tigris(
model_name: str = os.getenv("MODEL_PATH", "flux-1.dev"),
bucket_name: str = os.getenv("MODEL_BUCKET_NAME", "cipnahubakfu"),
destdir: str = "flux-cache",
destdir: str = ".",
n_cpus: int = os.cpu_count()
):
"""Copy files from Tigris to the destination folder. This will be done in parallel.
Expand Down
22 changes: 19 additions & 3 deletions scripts/prepare_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,25 @@ def batcher(iterable: Iterable, batch_size: int) -> Generator[List, None, None]:

def fetch_and_save_model(model_name, destdir):
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16)
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
from huggingface_hub import hf_hub_download

dtype = torch.bfloat16
pipe = None

match model_name:
case "ByteDance/SDXL-Lightning":
from safetensors.torch import load_file
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors"
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16")
case _:
pipe = StableDiffusionXLPipeline.from_pretrained(model_name, torch_dtype=dtype)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

pipe.save_pretrained(destdir, safe_serialization=True)


Expand Down

0 comments on commit 02c5148

Please sign in to comment.