In [None]:
import os

import numpy as np
import torch

from diffusers import StableDiffusionPipeline, AutoPipelineForImage2Image

from PIL import Image
import matplotlib.pyplot as plt

In [None]:
def seed_all(seed):
    """Seed the global NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed unchanged to both ``np.random.seed`` and
            ``torch.manual_seed``.
    """
    # The two generators are independent of each other, so the order in which
    # they are seeded does not matter.
    np.random.seed(seed)
    torch.manual_seed(seed)

def grid_show(images, rows=3):
    """Tile equally sized images into one grid and display it with matplotlib.

    Images are placed left to right, top to bottom. The number of columns is
    ``ceil(len(images) / rows)``; grid cells without an image stay black.

    Args:
        images: Non-empty sequence of equally sized images. Each item must
            expose a PIL-style ``size`` attribute ordered ``(width, height)``
            and be convertible with ``np.asarray``.
        rows: Number of rows in the grid.

    Returns:
        Tuple ``(grid, fig, ax)``: the un-normalised float array of shape
        ``(height * rows, width * columns, 3)`` plus the matplotlib figure and
        axes the grid was drawn on.

    Raises:
        ValueError: If ``images`` is empty or ``rows`` is smaller than 1.
    """
    number_images = len(images)
    if number_images == 0:
        raise ValueError("grid_show needs at least one image")
    if rows < 1:
        raise ValueError("rows must be at least 1")

    # PIL reports size as (width, height), whereas array axes are (height, width).
    width, height = images[0].size

    columns = int(np.ceil(number_images / rows))
    grid = np.zeros((height * rows, width * columns, 3))
    for ii, image in enumerate(images):
        tile = np.asarray(image)
        if tile.ndim == 2:
            # Greyscale input: repeat the single channel across R, G and B.
            tile = np.stack([tile] * 3, axis=-1)
        top = (ii // columns) * height
        left = (ii % columns) * width
        # Keep only the first three channels so an alpha channel cannot break
        # the assignment into the 3-channel grid.
        grid[top:top + height, left:left + width] = tile[..., :3]

    fig, ax = plt.subplots(1, 1, figsize=(3 * columns, 3 * rows))
    # Scale into [0, 1] for imshow; an all-black grid is shown unchanged so the
    # normalisation never divides zero by zero.
    peak = grid.max()
    ax.imshow(grid / peak if peak > 0 else grid)
    return grid, fig, ax

def callback_stash_latents(ii, tt, latents):
    """Decode the current latents and append the picture to the global ``images``.

    Written to be passed as the ``callback`` argument of ``pipe(...)``. It reads
    the module-level ``pipe`` and extends the module-level ``images`` list, so
    both must exist before the pipeline is run.

    Args:
        ii: Step index supplied by the pipeline (unused).
        tt: Timestep supplied by the pipeline (unused).
        latents: Latent tensor for the current step; only the first item of
            the batch is decoded.
    """
    # adapted from https://github.com/fastai/diffusion-nbs/stable_diffusion.ipynb
    # Use the scaling factor stored in the VAE config when there is one, and
    # fall back to the Stable Diffusion v1 constant otherwise.
    vae_config = getattr(pipe.vae, "config", None)
    scaling_factor = getattr(vae_config, "scaling_factor", 0.18215)

    # Decoding here is for inspection only, so no autograd graph is needed.
    with torch.no_grad():
        decoded = pipe.vae.decode(latents / scaling_factor).sample[0]

    # Map [-1, 1] to [0, 1], move channels last and hand a float32 array to NumPy.
    image = (decoded / 2. + 0.5).detach().cpu().permute(1, 2, 0).float().numpy()
    image = np.clip(image, 0, 1.0)
    images.extend(pipe.numpy_to_pil(image))

my_seed = 113
# seed_all() forwards the seed to np.random.seed, which only accepts integers
# in [0, 2**32 - 1]; fail right here instead of in a later cell.
assert 0 <= my_seed < 2**32, "my_seed must fit in an unsigned 32-bit integer"

In [None]:
# Select which checkpoint / device combination to load:
#   "sd14-cpu" - CompVis/stable-diffusion-v1-4, float32 on the CPU
#   "sd14-gpu" - CompVis/stable-diffusion-v1-4, float16 ("fp16" weights) on CUDA
#   "sdxl-cpu" - stabilityai/stable-diffusion-xl-base-1.0, float32 on the CPU
pipe_config = "sd14-cpu"

if pipe_config == "sd14-cpu":
    pipe_name = "CompVis/stable-diffusion-v1-4"
    my_dtype = torch.float32
    my_device = torch.device("cpu")
    pipe = StableDiffusionPipeline.from_pretrained(pipe_name, torch_dtype=my_dtype).to(my_device)
elif pipe_config == "sd14-gpu":
    pipe_name = "CompVis/stable-diffusion-v1-4"
    my_dtype = torch.float16
    my_device = torch.device("cuda")
    my_variant = "fp16"
    pipe = StableDiffusionPipeline.from_pretrained(pipe_name, variant=my_variant, torch_dtype=my_dtype).to(my_device)
elif pipe_config == "sdxl-cpu":
    # SDXL checkpoints use their own pipeline class, so let DiffusionPipeline
    # pick the class recorded in the checkpoint instead of forcing
    # StableDiffusionPipeline.
    # NOTE(review): callback_stash_latents and the legacy callback arguments
    # used further down were written for the v1 pipeline - confirm before
    # running those cells with SDXL.
    from diffusers import DiffusionPipeline

    pipe_name = "stabilityai/stable-diffusion-xl-base-1.0"
    my_dtype = torch.float32
    my_device = torch.device("cpu")
    pipe = DiffusionPipeline.from_pretrained(pipe_name, torch_dtype=my_dtype).to(my_device)
else:
    raise ValueError(f"unknown pipe_config: {pipe_config!r}")

In [None]:
# Sweep the guidance scale for one prompt and one seed, keeping the first image
# of every run for a side-by-side grid.
my_prompt = "Two mice, best friends, enjoy a cozy cup of tea together in a cozy cottage, illustration, characters, watercolor"
guidance_scales = [0.25, 0.5, 1., 2.0, 4.0, 6.0, 8.0, 10.0, 14.0]

guidance_images = []
for guidance in guidance_scales:
    # Re-seed before every run so that the guidance scale is the only change.
    seed_all(my_seed)
    my_output = pipe(
        my_prompt,
        num_inference_steps=50,
        num_images_per_prompt=1,
        guidance_scale=guidance,
    )
    guidance_images.append(my_output.images[0])

    # int(guidance*2) keeps the scale tag in the file name an integer.
    for ii, img in enumerate(my_output.images):
        img.save(f"mice_{my_seed}_g{int(guidance*2)}_tea3_{ii}.jpg")

temp = grid_show(guidance_images, rows=3)
plt.show()
   


In [None]:
# Both prompts in this cell are sampled with identical settings.
sampling_kwargs = dict(num_inference_steps=50, num_images_per_prompt=9, guidance_scale=9.0)

# --- accidental latte art ---
my_prompt = "😬 inadvertent 😬 latte art accidental latte art, cozy, detail, intricate, cafe, coffee"
seed_all(my_seed)
my_output = pipe(my_prompt, **sampling_kwargs)
temp = grid_show(my_output.images, rows=3)
plt.show()

for ii, img in enumerate(my_output.images):
    img.save(f"latte_{my_seed}_{ii}.jpg")

# --- face on burnt toast ---
my_prompt = "piece of toasted bread bears the likeness of a realistic human face on golden-brown burnt toast surface, shroud"
seed_all(my_seed)
my_output = pipe(my_prompt, **sampling_kwargs)
temp = grid_show(my_output.images, rows=3)
plt.show()

for ii, img in enumerate(my_output.images):
    img.save(f"toast_{my_seed}_{ii}.jpg")

In [None]:
# Prompt reused by the astronaut cells that follow. The two literals are joined
# by implicit string concatenation.
my_prompt = (
    "Artist's impression of first astronaut on Mars giving a thumbs-up 👍 after discovering fungoid alien Martian life"
    ", hyper-realistic, realism, retro-futuristic, intricate, detailed, golden hour"
)

In [None]:
seed_all(my_seed)

# callback_stash_latents appends to the global `images` list, so start from an
# empty list on every run of this cell.
images = []
my_output = pipe(
    my_prompt,
    num_inference_steps=50,
    callback=callback_stash_latents,
    callback_steps=6,
    num_images_per_prompt=1,
    guidance_scale=8.0,  # the pipeline keyword is `guidance_scale`, not `guidance`
)

# Finish the sequence with the fully denoised image.
images.append(my_output.images[0])

In [None]:
# Lay out the images stashed by the callback, followed by the final image, on a
# 3-row grid.
temp = grid_show(images, rows=3)
plt.show()

In [None]:
# Display the first image of the most recent `my_output` as the cell output.
my_output.images[0]

In [None]:
# Re-seed, then sample nine images for `my_prompt` with no negative prompt and
# the guidance scale left at the pipeline default.
seed_all(my_seed)
my_output_astro = pipe(my_prompt, num_inference_steps=50, num_images_per_prompt=9)

In [None]:
# Show the astronaut samples on the default 3-row grid, then save each image as
# a JPEG tagged with the seed and its index.
temp = grid_show(my_output_astro.images)
plt.show()
for ii, img in enumerate(my_output_astro.images):
    img.save(f"human_astro_{my_seed}_{ii}.jpg")

In [None]:
# Same prompt and seed as the plain astronaut run, but with a negative prompt
# listing human-related terms.
seed_all(my_seed)
my_output_alien = pipe(
    my_prompt,
    num_inference_steps=50,
    num_images_per_prompt=9,
    negative_prompt="human, astronaut, person, man, woman, Earthling",
)

In [None]:
# Show the negative-prompt samples on the default 3-row grid.
temp = grid_show(my_output_alien.images)
plt.show()

# Save each image as a JPEG tagged with the seed and its index.
for ii, img in enumerate(my_output_alien.images):
    img.save(f"nonhuman_astro_{my_seed}_{ii}.jpg")

In [None]:
# Image-to-image pipeline. It uses the same dtype and device as the
# text-to-image `pipe`, so switching the notebook to the GPU moves both.
# NOTE(review): confirm that this checkpoint id is still served by the Hub.
pipe_img2img = AutoPipelineForImage2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=my_dtype, use_safetensors=True
).to(my_device)

In [None]:

# Fetch the TRAPPIST-1e artist impression once; later runs reuse the local copy.
trappist_file = "600px-TRAPPIST-1e_artist_impression_2018.png"
if not os.path.exists(trappist_file):
    wget_status = os.system("wget 'https://upload.wikimedia.org/wikipedia/commons/thumb/3/38/TRAPPIST-1e_artist_impression_2018.png/600px-TRAPPIST-1e_artist_impression_2018.png'")
    # Stop with a clear message instead of failing later inside Image.open.
    if wget_status != 0 or not os.path.exists(trappist_file):
        raise RuntimeError(f"could not download {trappist_file} (wget exit status {wget_status})")

# Force three RGB channels (a PNG may carry alpha or a palette), then shrink to
# 128x128 and enlarge back to 512x512.
# NOTE(review): the down/up resize presumably removes fine detail on purpose - confirm.
init_image = Image.open(trappist_file).convert("RGB").resize((128,128)).resize((512,512))

init_image

In [None]:
# Re-seed, then run image-to-image on `init_image` with the astronaut prompt.
# The pipeline loaded earlier in the notebook is named `pipe_img2img`.
seed_all(my_seed)
my_output_img2img = pipe_img2img(prompt=my_prompt, guidance_scale=8.0, num_images_per_prompt=9, image=init_image)

In [None]:
# Show the image-to-image samples on a 3-row grid.
grid_show(my_output_img2img.images, rows=3)
plt.show()

# Save each image as a JPEG tagged with the seed and its index.
for ii, img in enumerate(my_output_img2img.images):
    img.save(f"img2img_human_astro_{my_seed}_{ii}.jpg")

In [None]:
seed_all(my_seed)

# Adjacent string literals are concatenated verbatim, so every piece has to
# carry its own separator.
trappist_prompt = (
    "Artist's impression of TRAPPIST-1e, a rocky water-world exoplanet ocean-bearing world,"
    " orbiting within the habitable (or Goldilocks) zone"
    " of the ultracool dwarf star TRAPPIST-1"
    ", NASA, artist concept, art, reconstruction"
)

# No negative prompt is applied here; put a string in this variable to steer
# the samples away from unwanted content.
# NOTE(review): the notebook defines no negative prompt before this cell, so
# the intended value (if any) is unknown.
trappist_negative_prompt = None

my_output_trappist1e = pipe_img2img(
    prompt=trappist_prompt,
    num_images_per_prompt=9,
    image=init_image,
    guidance_scale=5.0,
    negative_prompt=trappist_negative_prompt,
)

grid_show(my_output_trappist1e.images)
plt.show()

# File names follow the same <name>_<seed>_<index>.jpg pattern as the other cells.
for ii, img in enumerate(my_output_trappist1e.images):
    img.save(f"trappist1e_waterocean_{my_seed}_{ii}.jpg")

In [None]:
# https://www.jpl.nasa.gov/news/chasing-oumuamua
# (`os` is already imported in the first cell of the notebook.)
oumuamua_file = "imagesasteroid20180627Oumuamua.2e16d0ba.fill-400x400-c50.gif"
if not os.path.exists(oumuamua_file):
    wget_status = os.system("wget 'https://d2pn8kiwq2w21t.cloudfront.net/images/imagesasteroid20180627Oumuamua.2e16d0ba.fill-400x400-c50.gif'")
    # Stop with a clear message instead of failing later inside Image.open.
    if wget_status != 0 or not os.path.exists(oumuamua_file):
        raise RuntimeError(f"could not download {oumuamua_file} (wget exit status {wget_status})")

# Keep the GIF open as-is; a single frame is picked from it in the next cell.
init_image = Image.open(oumuamua_file)

In [None]:
# Pick frame 80 of the animation. The index is clamped to the last available
# frame so that a shorter file, or re-running this cell after `init_image` has
# already been replaced by a still image, does not raise EOFError.
frame_count = getattr(init_image, "n_frames", 1)
init_image.seek(min(80, frame_count - 1))

# GIF frames can be palette-based; convert to RGB before resizing.
init_image = init_image.convert("RGB").resize((512,512))


In [None]:
seed_all(my_seed)

oumuamua_prompt = (
    "Interstellar object"
    " Oumuamua is an elongated alien spacecraft"
    " artist concept, reconstruction, realistic render, NASA/JPL-Caltech "
)
negative_prompt = "normal asteroid"

# Pass the negative prompt defined just above to the image-to-image pipeline
# loaded earlier in the notebook (`pipe_img2img`).
oumuamua = pipe_img2img(
    prompt=oumuamua_prompt,
    num_images_per_prompt=9,
    image=init_image,
    guidance_scale=10.0,
    negative_prompt=negative_prompt,
)

grid_show(oumuamua.images)
plt.show()

# Save each image as a JPEG tagged with the seed and its index.
for ii, img in enumerate(oumuamua.images):
    img.save(f"oumuamua_{my_seed}_{ii}.jpg")

In [None]:
# https://photojournal.jpl.nasa.gov/catalog/PIA04413

# Fetch the rover photo once; later runs reuse the local copy.
rover_file = "300px-NASA_Mars_Rover.jpg"
if not os.path.exists(rover_file):
    wget_status = os.system("wget 'https://upload.wikimedia.org/wikipedia/commons/thumb/d/d8/NASA_Mars_Rover.jpg/300px-NASA_Mars_Rover.jpg'")
    # Stop with a clear message instead of failing later inside Image.open.
    if wget_status != 0 or not os.path.exists(rover_file):
        raise RuntimeError(f"could not download {rover_file} (wget exit status {wget_status})")

# Take the top-left 256x256 region as RGB and enlarge it to 512x512.
init_image = Image.open(rover_file).convert("RGB").crop((0,0,256,256)).resize((512,512))

seed_all(my_seed)

rover_prompt = (
    "Cute cartoon watercolor of NASA's Mars Opportunity rover, doing a good job on Mars"
    ", cozy, space, NASA, watercolour, art"
)

# `pipe_img2img` is the image-to-image pipeline loaded earlier in the notebook.
rover_wc = pipe_img2img(
    prompt=rover_prompt,
    num_images_per_prompt=9,
    image=init_image,
    guidance_scale=10.0,
)

grid_show(rover_wc.images)
plt.show()

# The arrow colour is the same for every figure, so look the colormap up once.
my_cmap = plt.get_cmap("plasma")

for ii, img in enumerate(rover_wc.images):
    img.save(f"rover_wc_{my_seed}_{ii}.jpg")

    # Side-by-side figure: diffusion result on the left, source image on the right.
    fig, ax = plt.subplots(1,2, figsize=(8,4))
    ax[1].imshow(init_image)
    ax[1].set_title("Initial image")
    ax[0].imshow(img)
    ax[0].set_title("After diffusion (watercolor)")

    # Large arrow drawn in figure coordinates between the two panels.
    fig.text(.44, .35, "→", color=my_cmap(192), fontsize=128)

    # Hide the pixel-coordinate tick labels on both panels.
    for idx in range(2):
        ax[idx].set_yticklabels([])
        ax[idx].set_xticklabels([])

    plt.show()
    # Release the figure so that one figure per image does not pile up in memory.
    plt.close(fig)