In [None]:
import torch
import treescope
import numpy as np

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import pt_to_pil
from evotorch import Problem
from evotorch.decorators import vectorized
from evotorch.algorithms import CMAES, SNES, CEM
from evotorch.logging import StdOutLogger, PandasLogger
from diffusion_pt import diffusion_sample
from fitness_fn import brightness, clip_fitness_fn, compose_fitness_fns
from vectorized_problem import VectorizedProblem
import matplotlib.pyplot as plt

# Route notebook object display through treescope.
treescope.register_as_default()
# NOTE(review): per the treescope docs, basic_interactive_setup() appears to
# perform the default-renderer registration itself (plus array
# autovisualization), which would make the call above redundant — confirm
# before removing either line.
treescope.basic_interactive_setup()

In [None]:
# Prefer the GPU when one is visible to torch; the pipeline is moved there below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint to load. The alternative below used to be assigned and then
# immediately overwritten; it is kept as a comment so it can still be toggled.
#   model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
model_id = "stabilityai/stable-diffusion-2-base"

# NOTE: cache_dir is a machine-specific absolute path; point it at a local
# directory when running this notebook elsewhere.
pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    # scheduler=scheduler,
    use_safetensors=True,
    cache_dir="/scratch/gilbreth/pjajal/hf_datasets",
).to(device)

In [3]:
# Forwarded to VectorizedProblem(splits=...) when the first problem is built.
problem_split = 2
# Requested number of denoising steps; replaced by the value diffusion_sample returns.
num_inference_steps = 15

# diffusion_sample returns a sampling callable, the initial latents and a step
# count. The generator is seeded so the starting latents are reproducible.
sample_fn, latents, num_inference_steps = diffusion_sample(
    pipeline,
    ["a picture of a black dog"],
    num_inference_steps,
    torch.Generator(device=device).manual_seed(23),
    guidance_scale=7.5,
    batch_size=1,
)

# CLIP-based objective for a text prompt. It is constructed here but is not
# included in `fit` below.
clip_fit = clip_fitness_fn(
    "openai/clip-vit-large-patch14",
    ["a picture of a cat"],
    cache_dir="/scratch/gilbreth/pjajal/hf_datasets",
)

# Objective used by the searches: brightness only. The second list presumably
# holds one weight per objective — confirm against compose_fitness_fns.
fit = compose_fitness_fns([brightness], [1])

In [12]:
def fitness(sample_fn, latent_shape, fitness_fn, num_inference_steps, mean_scale=1e-2):
  """Build a batched fitness callable over per-channel (mean, scale) parameters.

  A noise tensor of shape (b, h, w, c), scaled by 1e-6, is drawn once and reused
  for every evaluation. Each candidate solution carries ``2 * c`` numbers: the
  first ``c`` are a per-channel offset (multiplied by ``mean_scale``) and the
  last ``c`` multiply the fixed noise per channel.

  Args:
    sample_fn: Callable that receives latents laid out as (n, c, h, w) and
      returns an iterable of samples.
    latent_shape: Four-element shape ``(b, c, h, w)``.
    fitness_fn: Callable scoring a single sample that has been given a leading
      axis of size 1; its outputs are concatenated along dim 0.
    num_inference_steps: Not used by this function; kept in the signature.
    mean_scale: Multiplier applied to the per-channel offsets.

  Returns:
    A tuple ``(_fitness, random_vec)``: the fitness callable and the fixed noise
    tensor (channels-last) it closes over.

  Note:
    ``device`` is read from the enclosing module scope.
  """
  batch, channels, height, width = latent_shape
  # Channels-last layout so per-channel parameters broadcast over the last axis.
  random_vec = 1e-6 * torch.randn(batch, height, width, channels).to(device, dtype=torch.float32)

  def _as_channel_params(t):
    # (n, c) -> (n, 1, 1, c), float32, on `device`.
    return t.to(device, dtype=torch.float32).unsqueeze(1).unsqueeze(2)

  def _fitness(x):
    flat = x.reshape(-1, 2 * channels)
    mean_part, scale_part = flat.chunk(2, dim=-1)
    mean_term = _as_channel_params(mean_part)
    scale_term = _as_channel_params(scale_part)
    latents_nhwc = mean_scale * mean_term + random_vec * scale_term
    # The sampler expects channels-first latents.
    samples = sample_fn(latents_nhwc.permute(0, 3, 1, 2))
    scores = []
    for sample in samples:
      scores.append(fitness_fn(sample.unsqueeze(0)))
    return torch.cat(scores, dim=0)

  return _fitness, random_vec

In [None]:
# Scale applied to the per-channel offsets inside the fitness closure.
mean_scale = 0.01
fitness_fn, random_vec = fitness(
    sample_fn,
    latents.shape,
    fit,
    num_inference_steps,
    mean_scale=mean_scale,
)

# Each candidate holds two values per latent channel (offset and noise scale).
problem = VectorizedProblem(
    "max",
    fitness_fn,
    solution_length=2 * latents.shape[1],
    initial_bounds=(-2, 2),
    dtype=np.dtype('float32'),
    splits=problem_split,
)

# searcher = CMAES(problem, stdev_init=2)
searcher = SNES(problem, stdev_init=10)

# Loggers attached to the searcher: one prints to stdout, the other collects
# the statistics that are later read back with to_dataframe().
logger = StdOutLogger(searcher)
pandas_logger = PandasLogger(searcher)

In [None]:
# Baseline: sample from the fixed noise alone, before any search step.
a = sample_fn(random_vec.permute(0, 3, 1, 2))
plt.imshow(pt_to_pil(a)[0])
plt.show()

for step in range(25):
  searcher.step()
  best_idx = searcher.population.argbest()
  x = searcher.population[best_idx].values

  # Decode the best candidate the same way the fitness closure does. The row
  # width is derived from the latent channel count (two values per channel)
  # instead of a hard-coded 8, so it stays in sync with solution_length.
  x = x.reshape(-1, 2 * latents.shape[1])
  mean, cov_diag = x.chunk(2, dim=-1)
  # (n, c) -> (n, 1, 1, c) so the parameters broadcast over random_vec (b, h, w, c).
  mean = mean.to(device, dtype=torch.float32).unsqueeze(1).unsqueeze(2)
  cov_diag = cov_diag.to(device, dtype=torch.float32).unsqueeze(1).unsqueeze(2)
  x = mean_scale * mean + random_vec * cov_diag
  # x = (mean_scale * mean + random_vec)
  # Channels-last -> channels-first before sampling.
  x = x.permute(0, 3, 1, 2)
  a = sample_fn(x)
  plt.imshow(pt_to_pil(a)[0])
  plt.show()

In [None]:
# Read back the statistics collected by the pandas logger and plot the
# "median_eval" column.
my_data_frame = pandas_logger.to_dataframe()
my_data_frame.loc[:, "median_eval"].plot()
plt.show()

In [31]:
def fitness(sample_fn, latent_shape, fitness_fn, num_inference_steps, mean_scale=1e-2):
  """Build a batched fitness callable that searches the latents directly.

  Each candidate is a flat vector that is reshaped to ``(c, h, w)``, moved to
  ``device`` and multiplied by ``mean_scale`` before being handed to the sampler.

  Args:
    sample_fn: Callable that receives latents of shape (n, c, h, w) and returns
      an iterable of samples.
    latent_shape: Four-element shape ``(b, c, h, w)``; only the last three
      entries determine the per-candidate shape.
    fitness_fn: Callable scoring a single sample that has been given a leading
      axis of size 1; its outputs are concatenated along dim 0.
    num_inference_steps: Not used by this function; kept in the signature.
    mean_scale: Multiplier applied to the reshaped candidates.

  Returns:
    The fitness callable.

  Note:
    ``device`` is read from the enclosing module scope.
  """
  _batch, channels, height, width = latent_shape

  def _fitness(x):
    candidates = x.reshape(-1, channels, height, width).to(device)
    samples = sample_fn(candidates * mean_scale)
    scores = []
    for sample in samples:
      scores.append(fitness_fn(sample.unsqueeze(0)))
    return torch.cat(scores, dim=0)

  return _fitness

In [None]:
# Scale applied to the reshaped candidates inside the fitness closure.
mean_scale = 0.15
fitness_fn = fitness(sample_fn, latents.shape, fit, num_inference_steps, mean_scale=mean_scale)

# One search dimension per latent element. np.prod returns a NumPy scalar, so
# it is converted to a plain int for solution_length.
# A plain evotorch Problem used to be built here and was overwritten on the
# next line; it is kept only as a commented alternative:
#   problem = Problem("max", fitness_fn, solution_length=int(np.prod(latents.shape)), initial_bounds=(-1.5, 1.5), dtype=np.dtype('float32'))
problem = VectorizedProblem(
    "max",
    fitness_fn,
    solution_length=int(np.prod(latents.shape)),
    initial_bounds=(-2, 2),
    dtype=np.dtype('float32'),
    splits=4,
)

# searcher = SNES(problem, stdev_init=0.25)
searcher = CMAES(problem, stdev_init=1, separable=True)

# Loggers attached to the searcher: stdout printing and a pandas-backed record.
logger = StdOutLogger(searcher)
pandas_logger = PandasLogger(searcher)

In [None]:
# Reference image: the sampler called without explicit latents.
a = sample_fn()
plt.imshow(pt_to_pil(a)[0])
plt.show()

for step in range(25):
  searcher.step()
  best_idx = searcher.population.argbest()
  x = searcher.population[best_idx].values

  # Rebuild latents from the best flat candidate, applying the same scale the
  # fitness closure was built with, then show the resulting image.
  best_latents = x.reshape(-1, *latents.shape[1:]).to(device) * mean_scale
  a = sample_fn(best_latents)
  plt.imshow(pt_to_pil(a)[0])
  plt.show()