In [None]:
!nvidia-smi

In [None]:
import os
import json
import time
import torch
from PIL import Image
from tqdm import tqdm

from safetensors.torch import save_file
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora

# Make CUDA device index 1 the current device for this process; ImageProcessor
# below places the pipeline on the bare "cuda" device string.
# NOTE(review): the index is hard-coded — presumably this fails on hosts that
# expose fewer than two GPUs; confirm and adjust the index for your machine.
torch.cuda.set_device(1)

class ImageProcessor:
    """Thin wrapper around a FLUX pipeline for condition-guided generation.

    Loads the pipeline and its transformer from a model path in bfloat16,
    moves the pipeline to the "cuda" device, and exposes ``process_image``
    for inference with optional subject / spatial condition images.
    """

    def __init__(self, path):
        """Load the pipeline and transformer from ``path``.

        Args:
            path: Model location handed to both ``from_pretrained`` calls;
                the transformer is read from its ``transformer`` subfolder.
        """
        device = "cuda"
        self.pipe = FluxPipeline.from_pretrained(path, torch_dtype=torch.bfloat16, device=device)
        transformer = FluxTransformer2DModel.from_pretrained(
            path, subfolder="transformer", torch_dtype=torch.bfloat16, device=device
        )
        self.pipe.transformer = transformer
        self.pipe.to(device)

    def clear_cache(self, transformer):
        """Empty the ``bank_kv`` store of every attention processor on ``transformer``."""
        for attn_processor in transformer.attn_processors.values():
            attn_processor.bank_kv.clear()

    @staticmethod
    def _load_rgb_images(image_paths):
        """Open each path and return a list of RGB images (empty for None/empty input)."""
        images = []
        for image_path in image_paths or ():
            # convert() returns a new, fully loaded image, so the underlying
            # file handle can be released as soon as the conversion is done.
            with Image.open(image_path) as opened:
                images.append(opened.convert("RGB"))
        return images

    def process_image(self, prompt='', subject_imgs=None, spatial_imgs=None, height = 768, width = 768, output_path=None, seed=42):
        """Generate one image, display it, and optionally save it.

        Args:
            prompt: Text prompt passed to the pipeline.
            subject_imgs: Paths of subject condition images; ``None`` or an
                empty sequence means no subject condition.
            spatial_imgs: Paths of spatial condition images; ``None`` or an
                empty sequence means no spatial condition.
            height: Output height; coerced with ``int()``.
            width: Output width; coerced with ``int()``.
            output_path: If truthy, the generated image is saved there.
            seed: Seed for the CPU ``torch.Generator`` used for sampling.
        """
        spatial_ls = self._load_rgb_images(spatial_imgs)
        subject_ls = self._load_rgb_images(subject_imgs)

        try:
            image = self.pipe(
                prompt,
                height=int(height),
                width=int(width),
                guidance_scale=3.5,
                num_inference_steps=25,
                max_sequence_length=512,
                generator=torch.Generator("cpu").manual_seed(seed),
                subject_images=subject_ls,
                spatial_images=spatial_ls,
                cond_size=512,
            ).images[0]
        finally:
            # Clear the processors' bank_kv even when inference raises or is
            # interrupted, so a failed run does not leave entries behind for
            # the next call.
            self.clear_cache(self.pipe.transformer)
        image.show()
        if output_path:
            image.save(output_path)

In [None]:
### models path ###
base_path = "FLUX.1-dev"  # your flux model path
lora_path = "./models"  # your lora folder path

# spatial control LoRA checkpoints, all stored directly under lora_path
canny_path = f"{lora_path}/canny.safetensors"
depth_path = f"{lora_path}/depth.safetensors"
openpose_path = f"{lora_path}/pose.safetensors"
inpainting_path = f"{lora_path}/inpainting.safetensors"
hedsketch_path = f"{lora_path}/hedsketch.safetensors"
seg_path = f"{lora_path}/seg.safetensors"

# subject control LoRA checkpoint
subject_path = f"{lora_path}/subject.safetensors"

# build the shared image processor from the base FLUX model
processor = ImageProcessor(base_path)

Example: inference with a single condition (one control LoRA)

In [None]:
# Attach a single control LoRA (here: depth) to the transformer.
path = depth_path  # single control model path
lora_weights = [1]  # lora weights for each control model
set_single_lora(
    processor.pipe.transformer,
    path,
    lora_weights=lora_weights,
    cond_size=512,
)

# Run inference with one spatial condition image and no subject image.
prompt = "a cafe bar"
height = width = 1024
spatial_imgs = ["./test_imgs/depth.png"]
processor.process_image(
    prompt=prompt,
    subject_imgs=[],
    spatial_imgs=spatial_imgs,
    height=height,
    width=width,
    seed=11,
)

Example: inference with multiple conditions (several control LoRAs combined)

In [None]:
# Attach several control LoRAs at once: subject first, then inpainting.
paths = [subject_path, inpainting_path]  # multi control model paths
lora_weights = [[1], [1]]  # lora weights for each control model
set_multi_lora(processor.pipe.transformer, paths, lora_weights=lora_weights, cond_size=512)

# infer
prompt = 'A SKS on the car'
# The subject reference image goes to subject_imgs and the inpainting image to
# spatial_imgs, so each image is routed to the condition its file name
# describes (the two lists were previously assigned the other way round).
subject_imgs = ["./test_imgs/subject_1.png"]
spatial_imgs = ["./test_imgs/inpainting.png"]
height = 1024
width = 1024
processor.process_image(prompt=prompt, subject_imgs=subject_imgs, spatial_imgs=spatial_imgs, height=height, width=width, seed=42)