diff --git a/predict.py b/predict.py index 1bc5dc7..35dfcf6 100644 --- a/predict.py +++ b/predict.py @@ -79,10 +79,10 @@ def download_weights(url: str, dest: str) -> None: class Predictor(BasePredictor): def setup(self) -> None: """Load the model into memory to make running multiple predictions efficient""" - + if not os.path.exists(MODEL_CACHE): os.makedirs(MODEL_CACHE) - + model_files = [ "models--InstantX--SD3-Controlnet-Canny.tar", "models--stabilityai--stable-diffusion-3-medium-diffusers.tar", @@ -154,9 +154,13 @@ def predict( print(f"Using seed: {seed}") + # Preprocess the input image + input_image = Image.open(str(image_in)) + if input_image.mode != "RGB": + input_image = input_image.convert("RGB") + # Canny preprocessing - image_to_canny = load_image(str(image_in)) - image_to_canny = np.array(image_to_canny) + image_to_canny = np.array(input_image) image_to_canny = cv2.Canny(image_to_canny, 100, 200) image_to_canny = image_to_canny[:, :, None] image_to_canny = np.concatenate( @@ -182,6 +186,7 @@ def predict( image = image.resize((w, h), Image.LANCZOS) # Save the image with the specified format and quality + image = image.convert("RGB") extension = output_format.lower() extension = "jpeg" if extension == "jpg" else extension output_path = f"output.{extension}"