In [None]:
import torch
import numpy as np
from einops import einsum

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import pt_to_pil
from evotorch import Problem
from evotorch.decorators import vectorized
from evotorch.algorithms import CMAES, SNES, CEM
from evotorch.logging import StdOutLogger, PandasLogger
from diffusion_pt import diffusion_sample, DiffusionSample
from fitness_fn import brightness, clip_fitness_fn, compose_fitness_fns, relative_luminance, Novelty, pickscore_fitness_fn, aesthetic_fitness_fn
from vectorized_problem import VectorizedProblem
import matplotlib.pyplot as plt

from noise_injection import rotational_transform, rotational_transform_inject_multiple, multi_axis_rotational_transform, svd_rot_transform, multi_axis_svd_rot_transform

In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): absolute, user-specific cache location -- adjust for your environment.
CACHE_DIR = "/scratch/gilbreth/pjajal/hf_datasets"

model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
# model_id = "stabilityai/stable-diffusion-2-base"
pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    # scheduler=scheduler,
    use_safetensors=True,
    # Half precision only on CUDA: diffusers warns that float16 pipelines moved
    # to the CPU fail in most cases, so the CPU fallback loads float32 weights.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    cache_dir=CACHE_DIR,
).to(device)

In [None]:
# Experiment settings shared by the cells below.
problem_split = 2
num_inference_steps = 15
noise_scale = 1.0

# Sampler for the prompt, driven by a generator seeded with 98.
sample_fn = DiffusionSample(
    pipeline,
    ["a picture of a dog"],
    num_inference_steps,
    torch.Generator(device=device).manual_seed(98),
    guidance_scale=7.5,
    batch_size=1,
    inject_multiple=True,
    inject_multiple_noise_scale=noise_scale,
)
# The sampler reports the values the noise-injection transforms are given below;
# note that this rebinds `num_inference_steps`.
(centers, num_inference_steps, dtype) = sample_fn.noise_injection_args()

# Fitness: a single CLIP-based term with weight 1.
clip_fit = clip_fitness_fn(
    "openai/clip-vit-large-patch14",
    ["drawing of a cat"],
    cache_dir=CACHE_DIR,
    dtype=dtype,
)
# Alternative fitness terms, currently disabled:
# novelty = Novelty("dino_small", top_k=20, device=device, cache_dir=CACHE_DIR)
# pick = pickscore_fitness_fn(["a picture of an orange dog"], cache_dir=CACHE_DIR, device=device)
# pick = lambda x : 0
# aes = aesthetic_fitness_fn(CACHE_DIR, device=device, dtype=dtype)
fit = compose_fitness_fns([clip_fit], [1])

In [None]:
# --- Experiment: rotational transform with noise injected at multiple steps ---
mean_scale = 0.01
initial_bounds = (-1, 1)
injection_steps = num_inference_steps

# Per-step noise scale: linspace(1, 0, num_inference_steps) raised to the 8th power.
# A hand-written schedule ([1.0, 0.5, 0.25, 0.1]) used to be assigned here and was
# overwritten on the very next line before any use; it is kept only as this note.
noise_scale = (np.linspace(1, 0, num=num_inference_steps) ** 8).tolist()
sample_fn.inject_multiple_noise_scale = noise_scale

fitness_fn, inner_fn, centroid, solution_length = rotational_transform_inject_multiple(
    sample_fn,
    fit,
    sample_fn.latents.shape,
    device,
    injection_steps=injection_steps,
    center=centers,
    mean_scale=mean_scale,
    dtype=dtype,
)

# Maximisation problem over vectors of length `solution_length`.
problem = VectorizedProblem(
    "max",
    fitness_fn,
    solution_length=solution_length,
    initial_bounds=initial_bounds,
    dtype=np.dtype('float32'),
    splits=problem_split,
    initialization=None,
)
searcher = CMAES(problem, stdev_init=1, separable=True, csa_squared=True)
# searcher = SNES(problem, stdev_init=10)

# Loggers constructed on the searcher (stdout and pandas).
logger = StdOutLogger(searcher)
pandas_logger = PandasLogger(searcher)
print(f"pop. size: {searcher.popsize}")

In [None]:
# Reference image: what sample_fn produces when called with no arguments.
a = sample_fn()
plt.imshow(pt_to_pil(a)[0])
plt.show()

# Run 200 search steps with gradient tracking disabled, displaying the image
# generated from the population's best solution after every step.
with torch.no_grad():
    for step in range(200):
        searcher.step()
        best_idx = searcher.population.argbest()
        x = searcher.population[best_idx].values

        a = inner_fn(x)
        plt.imshow(pt_to_pil(a)[0])
        plt.show()

In [None]:
# --- Experiment: multi-axis rotational transform ---
mean_scale = 0.01
initial_bounds = (-1, 1)

fitness_fn, inner_fn, centroid, solution_length = multi_axis_rotational_transform(
    sample_fn,
    fit,
    sample_fn.latents.shape,
    device,
    center=centers,
    mean_scale=mean_scale,
    dtype=dtype,
)

# Maximisation problem over vectors of length `solution_length`.
problem = VectorizedProblem(
    "max",
    fitness_fn,
    solution_length=solution_length,
    initial_bounds=initial_bounds,
    dtype=np.dtype('float32'),
    splits=problem_split,
    initialization=None,
)
searcher = CMAES(problem, stdev_init=1, separable=True, csa_squared=True)
# searcher = SNES(problem, stdev_init=10)

# Loggers constructed on the searcher (stdout and pandas).
logger = StdOutLogger(searcher)
pandas_logger = PandasLogger(searcher)
print(f"pop. size: {searcher.popsize}")

In [None]:
# Reference image: what sample_fn produces when called with no arguments.
a = sample_fn()
plt.imshow(pt_to_pil(a)[0])
plt.show()

# Run 200 search steps with gradient tracking disabled, displaying the image
# generated from the population's best solution after every step.
with torch.no_grad():
    for step in range(200):
        searcher.step()
        best_idx = searcher.population.argbest()
        x = searcher.population[best_idx].values

        a = inner_fn(x)
        plt.imshow(pt_to_pil(a)[0])
        plt.show()

In [None]:
# --- Experiment: SVD rotational transform ---
mean_scale = 0.01
initial_bounds = (-1, 1)

fitness_fn, inner_fn, centroid, solution_length = svd_rot_transform(
    sample_fn,
    fit,
    sample_fn.latents.shape,
    device,
    center=centers,
    mean_scale=mean_scale,
    bound=0.01,
    dtype=dtype,
)

# Maximisation problem over vectors of length `solution_length`.
problem = VectorizedProblem(
    "max",
    fitness_fn,
    solution_length=solution_length,
    initial_bounds=initial_bounds,
    dtype=np.dtype('float32'),
    splits=problem_split,
    initialization=None,
)
# This experiment starts CMA-ES with stdev_init=5 rather than 1.
searcher = CMAES(problem, stdev_init=5, separable=True, csa_squared=True)
# searcher = SNES(problem, stdev_init=10)

# Loggers constructed on the searcher (stdout and pandas).
logger = StdOutLogger(searcher)
pandas_logger = PandasLogger(searcher)
print(f"pop. size: {searcher.popsize}")

In [None]:
# Reference image: what sample_fn produces when called with no arguments.
a = sample_fn()
plt.imshow(pt_to_pil(a)[0])
plt.show()

# Run 200 search steps, displaying the image generated from the population's
# best solution after every step. The loop runs under torch.no_grad(), like the
# other search loops in this notebook: the search only evaluates fitness values,
# so no autograd graph needs to be recorded during sampling.
with torch.no_grad():
    for step in range(200):
        searcher.step()
        best_idx = searcher.population.argbest()
        x = searcher.population[best_idx].values

        a = inner_fn(x)
        plt.imshow(pt_to_pil(a)[0])
        plt.show()

In [None]:
# --- Experiment: multi-axis SVD rotational transform ---
mean_scale = 0.0001
initial_bounds = (-1, 1)

fitness_fn, inner_fn, centroid, solution_length = multi_axis_svd_rot_transform(
    sample_fn,
    fit,
    sample_fn.latents.shape,
    device,
    center=centers,
    mean_scale=mean_scale,
    bound=0.001,
    dtype=dtype,
)

# Maximisation problem over vectors of length `solution_length`.
problem = VectorizedProblem(
    "max",
    fitness_fn,
    solution_length=solution_length,
    initial_bounds=initial_bounds,
    dtype=np.dtype('float32'),
    splits=problem_split,
    initialization=None,
)
searcher = CMAES(problem, stdev_init=1, separable=True, csa_squared=True)
# searcher = SNES(problem, stdev_init=10)

# Loggers constructed on the searcher (stdout and pandas).
logger = StdOutLogger(searcher)
pandas_logger = PandasLogger(searcher)
print(f"pop. size: {searcher.popsize}")

In [None]:
# Reference image: what sample_fn produces when called with no arguments.
a = sample_fn()
plt.imshow(pt_to_pil(a)[0])
plt.show()

# Run 200 search steps with gradient tracking disabled, displaying the image
# generated from the population's best solution after every step.
with torch.no_grad():
    for step in range(200):
        searcher.step()
        best_idx = searcher.population.argbest()
        x = searcher.population[best_idx].values

        a = inner_fn(x)
        plt.imshow(pt_to_pil(a)[0])
        plt.show()