In [3]:
import treescope
import numpy as np
import jax
import jax.numpy as jnp
import optax
import torch
from evosax.algorithms import Open_ES, MA_ES, DifferentialEvolution

from diffusers import LatentConsistencyModelPipeline
from noise_injection_pipelines import LCMSamplingPipeline, noise
from fitness import brightness

# Set up treescope's interactive renderer for notebook outputs
# (presumably what renders `es.default_params` / `state` below — see treescope docs).
treescope.basic_interactive_setup()

In [4]:
# Load the LCM Dreamshaper v7 checkpoint in half precision from safetensors.
# The safety checker is intentionally disabled; diffusers logs a warning about
# this at load time (see the captured output below).
pipeline = LatentConsistencyModelPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    safety_checker=None,
    use_safetensors=True,
    torch_dtype=torch.float16,
    device_map="balanced",
)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


In [5]:
# Wrap the LCM pipeline in the noise-injection sampler: 4 inference steps,
# classifier-free guidance at 7.5, 512x512 output. The prompt starts empty and
# is re-embedded with the real prompt in a later cell.
# torch.Generator("mps") raises on machines without Apple-silicon MPS, so fall
# back to a seeded CPU generator there; on MPS machines this is unchanged.
# NOTE(review): assumes LCMSamplingPipeline accepts a CPU generator — confirm.
generator_device = "mps" if torch.backends.mps.is_available() else "cpu"
sample_fn = LCMSamplingPipeline(
    pipeline=pipeline,
    prompt="",
    num_inference_steps=4,
    classifier_free_guidance=True,
    guidance_scale=7.5,
    generator=torch.Generator(generator_device).manual_seed(0),
    add_noise=False,
    height=512,
    width=512,
)

In [6]:
# Refresh the sampler before use: regenerate its latents, then re-embed the
# real prompt (the sampler was constructed above with prompt="").
# NOTE(review): exact semantics of regenerate_latents / rembed_text live in
# noise_injection_pipelines — confirm there.
sample_fn.regenerate_latents()
sample_fn.rembed_text("An orange cat sitting on a couch")

In [None]:
# Preview: sample without a `noise_injection` argument and display the first
# image of the returned sequence.
sample_fn()[0]

In [None]:
# Alternative optimiser: Open_ES with plain SGD (learning rate 1.0) and 4
# members, each shaped like the (4, 64, 64) zero template.
# NOTE: `es` is overwritten by the DifferentialEvolution cell below.
es = Open_ES(
    solution=jnp.zeros((4, 64, 64)),
    population_size=4,
    optimizer=optax.sgd(1.0),
)

In [None]:
# Alternative optimiser: MA_ES with default settings and 8 members, each shaped
# like the (4, 64, 64) zero template.
# NOTE: `es` is overwritten by the DifferentialEvolution cell below.
es = MA_ES(
    solution=jnp.zeros((4, 64, 64)),
    population_size=8,
)

In [8]:
# Optimiser used by the rest of the notebook: DifferentialEvolution over 8
# candidate noise tensors, each shaped like the (4, 64, 64) zero template.
es = DifferentialEvolution(
    solution=jnp.zeros((4, 64, 64)),
    population_size=8,
)

In [9]:
# Show the algorithm's default hyperparameters; this same object is passed as
# `params` to init/ask/tell in the cells below.
es.default_params

In [None]:
# Initialize state
key = jax.random.key(0)
key, pop_key = jax.random.split(key)
params = es.default_params
population = jax.random.normal(pop_key, (8, 4, 64, 64))

torch_population = torch.from_dlpack(population)
initial_fitness = 

state = es.init(key, pop,  params)

In [20]:
state

In [19]:
best_sols = []
for i in range(10):
    key, key_ask, key_eval = jax.random.split(key, 3)
    population, state = es.ask(key_ask, state, params)

    # Evaluate the population
    torch_population = torch.from_dlpack(population)
    imgs = []
    for latents in torch_population.split(2):
        img = sample_fn(noise_injection=latents)
        imgs.extend(img)
    fitnesses = [brightness(img) for img in imgs]
    idx_max = np.argmax(fitnesses)
    best_img = imgs[idx_max]

    fitnesses = jnp.array(fitnesses).squeeze()
    # Tell the ES about the fitnesses
    state, metrics = es.tell(key_eval, population, fitnesses, state, params)
    # Get the best solution
    best_sol = state.best_solution
    best_sols.append((best_img, state.best_fitness))
    display(best_img)
    print(f"Best fitness: {state.best_fitness}")

ValueError: Cannot take a larger sample (size 2) than population (size 1) when 'replace=False'