In [2]:
import torch
import numpy as np
from einops import einsum

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import pt_to_pil
from evotorch import Problem
from evotorch.decorators import vectorized
from evotorch.algorithms import CMAES, SNES, CEM
from evotorch.logging import StdOutLogger, PandasLogger
from noise_injection_pipelines.diffusion_pt import diffusion_sample, DiffusionSample
from fitness.fitness_fn import brightness, clip_fitness_fn, compose_fitness_fns, relative_luminance, Novelty, pickscore_fitness_fn, aesthetic_fitness_fn
from evo.vectorized_problem import VectorizedProblem
import matplotlib.pyplot as plt

from noise_injection_pipelines.noise_injection import rotational_transform, rotational_transform_inject_multiple, multi_axis_rotational_transform, svd_rot_transform, multi_axis_svd_rot_transform

In [3]:
# Load the Stable Diffusion 3.5 (medium) checkpoint with float16 weights.
# NOTE(review): a later cell rebinds `pipeline` to a Stable Diffusion v1.5
# checkpoint, so this one appears to be loaded only for inspection — confirm
# the download is still wanted.
model_id = "stabilityai/stable-diffusion-3.5-medium"
pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    use_safetensors=True,
    torch_dtype=torch.float16,
)

model_index.json:   0%|          | 0.00/706 [00:00<?, ?B/s]

Fetching 26 files:   0%|          | 0/26 [00:00<?, ?it/s]

text_encoder_2/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

text_encoder/config.json:   0%|          | 0.00/574 [00:00<?, ?B/s]

scheduler/scheduler_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

text_encoder_3/config.json:   0%|          | 0.00/740 [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.53G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/247M [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

tokenizer/special_tokens_map.json:   0%|          | 0.00/588 [00:00<?, ?B/s]

tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer/tokenizer_config.json:   0%|          | 0.00/705 [00:00<?, ?B/s]

(…)t_encoder_3/model.safetensors.index.json:   0%|          | 0.00/19.9k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.39G [00:00<?, ?B/s]

tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

tokenizer_2/special_tokens_map.json:   0%|          | 0.00/576 [00:00<?, ?B/s]

tokenizer_2/tokenizer_config.json:   0%|          | 0.00/856 [00:00<?, ?B/s]

tokenizer_3/special_tokens_map.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

tokenizer_3/tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

transformer/config.json:   0%|          | 0.00/524 [00:00<?, ?B/s]

tokenizer_3/tokenizer_config.json:   0%|          | 0.00/20.6k [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

vae/config.json:   0%|          | 0.00/809 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/168M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/9 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


In [4]:
# Bare expression: the notebook renders the pipeline's repr (its component listing) as the cell output.
pipeline

StableDiffusion3Pipeline {
  "_class_name": "StableDiffusion3Pipeline",
  "_diffusers_version": "0.32.1",
  "_name_or_path": "stabilityai/stable-diffusion-3.5-medium",
  "feature_extractor": [
    null,
    null
  ],
  "image_encoder": [
    null,
    null
  ],
  "scheduler": [
    "diffusers",
    "FlowMatchEulerDiscreteScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModelWithProjection"
  ],
  "text_encoder_2": [
    "transformers",
    "CLIPTextModelWithProjection"
  ],
  "text_encoder_3": [
    "transformers",
    "T5EncoderModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "tokenizer_2": [
    "transformers",
    "CLIPTokenizer"
  ],
  "tokenizer_3": [
    "transformers",
    "T5TokenizerFast"
  ],
  "transformer": [
    "diffusers",
    "SD3Transformer2DModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Cache directory passed to `from_pretrained` below.
CACHE_DIR = "/scratch/gilbreth/pjajal/hf_datasets"

# Checkpoint to sample from.
# Alternative previously tried: "stabilityai/stable-diffusion-2-base"
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Load float16 safetensors weights, then move the whole pipeline to `device`.
pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    use_safetensors=True,
    torch_dtype=torch.float16,
    cache_dir=CACHE_DIR,
)
pipeline = pipeline.to(device)

In [None]:
# Search / sampling settings shared by the experiment cells below.
problem_split = 2
num_inference_steps = 15
noise_scale = 1.0

# Sampler for a fixed prompt, driven by a generator seeded with 98 so that
# repeated runs start from the same random state.
sample_fn = DiffusionSample(
    pipeline,
    ["a picture of a dog"],
    num_inference_steps,
    torch.Generator(device=device).manual_seed(98),
    guidance_scale=7.5,
    batch_size=1,
    inject_multiple=True,
    inject_multiple_noise_scale=noise_scale,
)

# `num_inference_steps` is rebound here to the value the sampler reports.
centers, num_inference_steps, dtype = sample_fn.noise_injection_args()

# Fitness term built from CLIP ViT-L/14 and the text "drawing of a cat".
clip_fit = clip_fitness_fn(
    "openai/clip-vit-large-patch14",
    ["drawing of a cat"],
    cache_dir=CACHE_DIR,
    dtype=dtype,
)

# Other fitness terms that have been tried; currently disabled.
# novelty = Novelty("dino_small", top_k=20, device=device, cache_dir=CACHE_DIR)
# pick = pickscore_fitness_fn(["a picture of an orange dog"], cache_dir=CACHE_DIR, device=device)
# pick = lambda x : 0
# aes = aesthetic_fitness_fn(CACHE_DIR, device=device, dtype=dtype)

# Compose the active terms; the second list is presumably per-term weights — TODO confirm.
fit = compose_fitness_fns([clip_fit], [1])

In [None]:
# --- Experiment: rotational transform with noise injected at multiple steps ---
mean_scale = 0.01
initial_bounds = (-1, 1)
injection_steps = num_inference_steps

# Per-step noise scale: a linear ramp from 1 down to 0 over the inference
# steps, raised to the 8th power.
# (A hand-written schedule [1.0, 0.5, 0.25, 0.1] was previously assigned here
# and immediately overwritten; that dead assignment has been removed.)
noise_scale = (np.linspace(1, 0, num=num_inference_steps) ** 8).tolist()

# NOTE: this mutates the shared sampler, so later cells that call `sample_fn`
# see the new schedule as well.
sample_fn.inject_multiple_noise_scale = noise_scale

fitness_fn, inner_fn, centroid, solution_length = rotational_transform_inject_multiple(
    sample_fn,
    fit,
    sample_fn.latents.shape,
    device,
    injection_steps=injection_steps,
    center=centers,
    mean_scale=mean_scale,
    dtype=dtype,
)

# Maximisation problem over a float32 solution vector of length `solution_length`.
problem = VectorizedProblem(
    "max",
    fitness_fn,
    solution_length=solution_length,
    initial_bounds=initial_bounds,
    dtype=np.dtype('float32'),
    splits=problem_split,
    initialization=None,
)
searcher = CMAES(problem, stdev_init=1, separable=True, csa_squared=True)
# searcher = SNES(problem, stdev_init=10)

# Loggers attach themselves to the searcher when constructed.
logger = StdOutLogger(searcher)
pandas_logger = PandasLogger(searcher)
print(f"pop. size: {searcher.popsize}")

In [None]:
# Baseline: show what the sampler produces before any search step.
a = sample_fn()
baseline_image = pt_to_pil(a)[0]
plt.imshow(baseline_image)
plt.show()

# Run 200 search generations; after each one, render the image produced by
# the best member of the current population.
with torch.no_grad():
    for step in range(200):
        searcher.step()

        best_idx = searcher.population.argbest()
        x = searcher.population[best_idx].values

        a = inner_fn(x)
        best_image = pt_to_pil(a)[0]
        plt.imshow(best_image)
        plt.show()

In [None]:
# --- Experiment: multi-axis rotational transform ---
mean_scale = 0.01
initial_bounds = (-1, 1)

fitness_fn, inner_fn, centroid, solution_length = multi_axis_rotational_transform(
    sample_fn,
    fit,
    sample_fn.latents.shape,
    device,
    center=centers,
    mean_scale=mean_scale,
    dtype=dtype,
)

# Maximisation problem over a float32 solution vector of length `solution_length`.
problem = VectorizedProblem(
    "max",
    fitness_fn,
    solution_length=solution_length,
    initial_bounds=initial_bounds,
    dtype=np.dtype('float32'),
    splits=problem_split,
    initialization=None,
)
searcher = CMAES(problem, stdev_init=1, separable=True, csa_squared=True)
# searcher = SNES(problem, stdev_init=10)

# Loggers attach themselves to the searcher when constructed.
logger = StdOutLogger(searcher)
pandas_logger = PandasLogger(searcher)
print(f"pop. size: {searcher.popsize}")

In [None]:
# Baseline: show what the sampler produces before any search step.
a = sample_fn()
baseline_image = pt_to_pil(a)[0]
plt.imshow(baseline_image)
plt.show()

# Run 200 search generations; after each one, render the image produced by
# the best member of the current population.
with torch.no_grad():
    for step in range(200):
        searcher.step()

        best_idx = searcher.population.argbest()
        x = searcher.population[best_idx].values

        a = inner_fn(x)
        best_image = pt_to_pil(a)[0]
        plt.imshow(best_image)
        plt.show()

In [None]:
# --- Experiment: SVD-based rotational transform ---
mean_scale = 0.01
initial_bounds = (-1, 1)

fitness_fn, inner_fn, centroid, solution_length = svd_rot_transform(
    sample_fn,
    fit,
    sample_fn.latents.shape,
    device,
    center=centers,
    mean_scale=mean_scale,
    bound=0.01,
    dtype=dtype,
)

# Maximisation problem over a float32 solution vector of length `solution_length`.
problem = VectorizedProblem(
    "max",
    fitness_fn,
    solution_length=solution_length,
    initial_bounds=initial_bounds,
    dtype=np.dtype('float32'),
    splits=problem_split,
    initialization=None,
)
# This experiment starts CMA-ES with stdev_init=5, unlike the other experiment cells, which use 1.
searcher = CMAES(problem, stdev_init=5, separable=True, csa_squared=True)
# searcher = SNES(problem, stdev_init=10)

# Loggers attach themselves to the searcher when constructed.
logger = StdOutLogger(searcher)
pandas_logger = PandasLogger(searcher)
print(f"pop. size: {searcher.popsize}")

In [None]:
# Baseline: show what the sampler produces before any search step.
a = sample_fn()
plt.imshow(pt_to_pil(a)[0])
plt.show()

# Run 200 search generations; after each one, render the image produced by
# the best member of the current population.
# The loop runs under torch.no_grad(), matching the other experiment loops in
# this notebook: the search is gradient-free, so tracking autograd state
# through every sampling call would only cost memory.
with torch.no_grad():
    for step in range(200):
        searcher.step()
        best_idx = searcher.population.argbest()
        x = searcher.population[best_idx].values

        a = inner_fn(x)
        plt.imshow(pt_to_pil(a)[0])
        plt.show()

In [None]:
# --- Experiment: multi-axis SVD-based rotational transform ---
mean_scale = 0.0001
initial_bounds = (-1, 1)

fitness_fn, inner_fn, centroid, solution_length = multi_axis_svd_rot_transform(
    sample_fn,
    fit,
    sample_fn.latents.shape,
    device,
    center=centers,
    mean_scale=mean_scale,
    bound=0.001,
    dtype=dtype,
)

# Maximisation problem over a float32 solution vector of length `solution_length`.
problem = VectorizedProblem(
    "max",
    fitness_fn,
    solution_length=solution_length,
    initial_bounds=initial_bounds,
    dtype=np.dtype('float32'),
    splits=problem_split,
    initialization=None,
)
searcher = CMAES(problem, stdev_init=1, separable=True, csa_squared=True)
# searcher = SNES(problem, stdev_init=10)

# Loggers attach themselves to the searcher when constructed.
logger = StdOutLogger(searcher)
pandas_logger = PandasLogger(searcher)
print(f"pop. size: {searcher.popsize}")

In [None]:
# Baseline: show what the sampler produces before any search step.
a = sample_fn()
baseline_image = pt_to_pil(a)[0]
plt.imshow(baseline_image)
plt.show()

# Run 200 search generations; after each one, render the image produced by
# the best member of the current population.
with torch.no_grad():
    for step in range(200):
        searcher.step()

        best_idx = searcher.population.argbest()
        x = searcher.population[best_idx].values

        a = inner_fn(x)
        best_image = pt_to_pil(a)[0]
        plt.imshow(best_image)
        plt.show()