In [1]:
import torch
import numpy as np
from einops import einsum

from diffusers import DiffusionPipeline
from evotorch import Problem
from evotorch.decorators import vectorized
from evotorch.algorithms import CMAES, SNES, CEM
from evotorch.logging import StdOutLogger, PandasLogger
from noise_injection_pipelines.sampling_pipelines import SDXLSamplingPipeline
from fitness.fitness_fn import (
    brightness,
    clip_fitness_fn,
    compose_fitness_fns,
    relative_luminance,
    Novelty,
    pickscore_fitness_fn,
    aesthetic_fitness_fn,
)
from evo.vectorized_problem import VectorizedProblem
from diffusers.utils import pt_to_pil, numpy_to_pil
import matplotlib.pyplot as plt

from noise_injection_pipelines.noise_injection import rotational_transform

In [None]:
# Load SDXL-Turbo in half precision from safetensors weights and move it to the GPU.
model_name = "stabilityai/sdxl-turbo"
pipe = DiffusionPipeline.from_pretrained(
    model_name,
    use_safetensors=True,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
# Memory-efficient attention kernels from xformers (requires xformers installed).
pipe.enable_xformers_memory_efficient_attention()


In [9]:
# Passed as `splits=` to VectorizedProblem below (presumably the number of chunks
# a population is evaluated in, to bound GPU memory — TODO confirm in evo.vectorized_problem).
problem_split = 6
num_inference_steps = 15
# Assigned to `sample_fn.inject_multiple_noise_scale` in the problem-setup cell.
noise_scale = 1.0
# Single source of truth for the text prompt: the generator and the CLIP fitness
# must refer to the same text, so it is defined once instead of duplicated.
prompt = "A beautiful landscape painting"

sample_fn = SDXLSamplingPipeline(
    pipe,
    prompt=prompt,
    num_inference_steps=num_inference_steps,
    # Fixed seed so sampling is reproducible across kernel restarts.
    generator=torch.Generator(device="cuda").manual_seed(0),
)
clip_fit = clip_fitness_fn(
    "openai/clip-vit-large-patch14", [prompt], dtype=pipe.dtype
)
# Weighted composition of fitness terms; currently a single CLIP term with weight 1.
fit = compose_fitness_fns([clip_fit], [1])

In [None]:
# --- Search configuration -------------------------------------------------
mean_scale = 0.01  # NOTE(review): not referenced elsewhere in this notebook
initial_bounds = (-1, 1)
injection_steps = num_inference_steps  # NOTE(review): not referenced elsewhere in this notebook
sample_fn.inject_multiple_noise_scale = noise_scale

# rotational_transform wraps the sampler and fitness into a search-space
# formulation centred on the sampler's current latents. `inner_fn` is used in the
# optimisation cell to render a solution vector; `centroid` is unused here.
(
    fitness_fn,
    inner_fn,
    centroid,
    solution_length,
) = rotational_transform(
    sample_fn,
    fit,
    sample_fn.latents.shape,
    center=sample_fn.latents,
    device=pipe.device,
    dtype=pipe.dtype,
)

# Maximisation problem over float32 vectors of length `solution_length`.
problem = VectorizedProblem(
    "max",
    fitness_fn,
    dtype=np.dtype("float32"),
    initial_bounds=initial_bounds,
    initialization=None,
    solution_length=solution_length,
    splits=problem_split,
)

# Separable CMA-ES. An alternative tried previously: SNES with stdev_init=10.
searcher = CMAES(problem, separable=True, stdev_init=1, csa_squared=True)

# Attach loggers: one prints per-generation status, the other records it for
# later inspection as a DataFrame.
logger = StdOutLogger(searcher)
pandas_logger = PandasLogger(searcher)
print(f"pop. size: {searcher.popsize}")

In [None]:
# Number of search generations to run; each generation displays its best image.
num_generations = 200


def _show_image(image, title):
    """Display the first image of a batch with a title, then release the figure.

    ``image`` is passed unchanged to ``numpy_to_pil`` (as in the original cell), so it
    is presumably a NumPy image batch — TODO confirm against SDXLSamplingPipeline.
    """
    fig, ax = plt.subplots()
    ax.imshow(numpy_to_pil(image)[0])
    ax.set_title(title)
    ax.axis("off")
    plt.show()
    # Explicit close so hundreds of figures do not accumulate on non-inline backends.
    plt.close(fig)


# Everything here is inference-only. The baseline sample is now also drawn under
# no_grad, matching the per-generation calls to inner_fn.
with torch.no_grad():
    a = sample_fn()
    _show_image(a, "Baseline (before search)")

    for step in range(num_generations):
        searcher.step()
        # Render the best member of the current population.
        best_idx = searcher.population.argbest()
        x = searcher.population[best_idx].values

        a = inner_fn(x)
        _show_image(a, f"Generation {step + 1}/{num_generations}: population best")