In [None]:
# !nvidia-smi
# !pip install diffusers==0.3.0
# !pip install transformers scipy ftfy
# !pip install "ipywidgets>=7,<8"

import pdb
import torch
from diffusers import StableDiffusionPipeline
from tqdm.auto import tqdm
from torch import autocast
from PIL import Image
import pdb

# Hugging Face access token handed to from_pretrained below; replace the
# placeholder with a real token before running.
YOUR_TOKEN="REPLACE WITH YOUR HUGGINGFACE TOKEN"
# Load the Stable Diffusion v1-4 pipeline and move it to the GPU.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=YOUR_TOKEN)  
pipe = pipe.to("cuda")

# Expose the pipeline components as module-level globals; the functions
# defined below read them directly instead of taking them as arguments.
tokenizer, text_encoder, unet, scheduler, vae = pipe.tokenizer, pipe.text_encoder, pipe.unet, pipe.scheduler, pipe.vae

# Every prompt is padded (and, in get_text_embeddings, truncated) to this length.
max_length = tokenizer.model_max_length
# Device that token ids and latents are moved to before use.
torch_device = 'cuda'
# Number of empty prompts encoded for the unconditional branch of
# classifier-free guidance.
batch_size = 1


def get_text_embeddings(prompt):
    """Encode a prompt with the pipeline's tokenizer and text encoder.

    The prompt is padded/truncated to ``max_length`` tokens and the token ids
    are moved to ``torch_device`` before being fed to ``text_encoder``.

    Args:
        prompt: Prompt text handed to the tokenizer.

    Returns:
        The first output of ``text_encoder`` for the tokenized prompt. It is
        computed without autograd tracking, so it carries no graph.
    """
    text_input = tokenizer(prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt")
    # Inference only: without no_grad the returned tensor would keep the
    # encoder's autograd graph (and its activations) alive for as long as the
    # embedding is referenced.
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

    return text_embeddings

def im_gen(text_embeddings, latents, num_inference_steps=200, guidance_scale=10):
    """Generate one image from prompt embeddings using classifier-free guidance.

    Args:
        text_embeddings: Prompt embeddings, e.g. from ``get_text_embeddings``.
        latents: Initial latents; moved to ``torch_device`` before sampling.
        num_inference_steps: Number of steps passed to ``scheduler.set_timesteps``.
        guidance_scale: Weight applied to the (text - unconditional) noise
            difference.

    Returns:
        A ``PIL.Image`` built from the first decoded sample.
    """
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    # Inference only: do not record an autograd graph for the encoder pass.
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
        # Unconditional first, text second: the order noise_pred.chunk(2) assumes below.
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    # latents = latents * scheduler.init_noise_sigma
    latents = latents.to(torch_device)

    scheduler.set_timesteps(num_inference_steps)
    with autocast("cuda"), torch.no_grad():
        # Wrap the timesteps themselves so tqdm can read their length and show
        # a total / ETA (an enumerate object has no len()).
        for t in tqdm(scheduler.timesteps):
            # Duplicate the latents so both guidance branches run in one forward pass.
            latent_model_input = torch.cat([latents] * 2)
            # Predict the noise residual for both branches.
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
            # Perform guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            # Compute the previous noisy sample x_t -> x_t-1.
            latents = scheduler.step(noise_pred, t, latents).prev_sample

    with torch.no_grad():
        # Rescale the latents by 1/0.18215 before decoding
        # (presumably the Stable Diffusion VAE latent scaling factor).
        latents = 1 / 0.18215 * latents
        image = vae.decode(latents).sample

    # Map the decoder output from [-1, 1] to [0, 1], move channels last, and
    # quantize to uint8 for PIL.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")

    return Image.fromarray(images[0])

def transfer_score(emb0, emb_add, emb_minus, latents, strength = 1,
                   num_inference_steps=200, guidance_scale=10, correction_scale = 1):
    """Generate an image while shifting the text noise prediction by a concept difference.

    During the last ``int(num_inference_steps * strength)`` steps the noise
    prediction for ``emb0`` is modified by adding
    ``correction_scale * (pred(emb_add) - pred(emb_minus))`` before
    classifier-free guidance is applied. Earlier steps use plain guidance on
    ``emb0``.

    Args:
        emb0: Embeddings of the base prompt.
        emb_add: Embeddings of the concept to add.
        emb_minus: Embeddings of the concept to remove.
        latents: Initial latents; moved to ``torch_device`` before sampling.
        strength: Fraction of the steps (counted from the end) that use the
            modified prediction; ``1`` modifies every step.
        num_inference_steps: Number of steps passed to ``scheduler.set_timesteps``.
        guidance_scale: Weight applied to the (text - unconditional) noise
            difference.
        correction_scale: Multiplier on the (add - minus) noise difference.

    Returns:
        A ``PIL.Image`` built from the first decoded sample.
    """
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    # Inference only: do not record an autograd graph for the encoder pass.
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
        # The conditioning batches do not change between steps, so build them once.
        plain_context = torch.cat([uncond_embeddings, emb0])
        transfer_context = torch.cat([uncond_embeddings, emb0, emb_add, emb_minus])

    # latents = latents * scheduler.init_noise_sigma
    latents = latents.to(torch_device)

    scheduler.set_timesteps(num_inference_steps)
    t0 = int(num_inference_steps * strength)
    first_stage_steps = num_inference_steps - t0

    with autocast("cuda"), torch.no_grad():
        # enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the progress
        # bar knows the number of steps.
        for i, t in enumerate(tqdm(scheduler.timesteps)):
            if i < first_stage_steps:
                # First stage: ordinary two-branch guidance on the base prompt.
                latent_model_input = torch.cat([latents] * 2)
                noise_pred = unet(latent_model_input, t,
                                  encoder_hidden_states=plain_context).sample
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            else:
                # Second stage: one pass over uncond / base / add / minus.
                latent_model_input = torch.cat([latents] * 4)
                noise_pred = unet(latent_model_input, t,
                                  encoder_hidden_states=transfer_context).sample
                noise_pred_uncond, noise_pred_text, noise_pred_text_plus, noise_pred_text_minus = noise_pred.chunk(4)
                # Shift the base prediction along the (add - minus) direction.
                noise_pred_text += (noise_pred_text_plus - noise_pred_text_minus) * correction_scale

            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents).prev_sample

    with torch.no_grad():
        # Rescale the latents by 1/0.18215 before decoding
        # (presumably the Stable Diffusion VAE latent scaling factor).
        latents = 1 / 0.18215 * latents
        image = vae.decode(latents).sample

    # Map the decoder output from [-1, 1] to [0, 1], move channels last, and
    # quantize to uint8 for PIL.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")

    return Image.fromarray(images[0])

def transfer_prompt(emb0, emb_add, emb_minus, latents,
                    num_inference_steps = 200, guidance_scale = 10, strength = 1):
    """Generate an image from an arithmetically edited prompt embedding.

    The edited embedding is ``emb0 - emb_minus + emb_add``, rescaled so that
    its overall L2 norm equals that of ``emb0``. It replaces ``emb0`` during
    the last ``int(num_inference_steps * strength)`` steps; earlier steps use
    ``emb0`` unchanged.

    Args:
        emb0: Embeddings of the base prompt (not modified).
        emb_add: Embeddings of the concept to add.
        emb_minus: Embeddings of the concept to remove.
        latents: Initial latents; moved to ``torch_device`` before sampling.
        num_inference_steps: Number of steps passed to ``scheduler.set_timesteps``.
        guidance_scale: Weight applied to the (text - unconditional) noise
            difference.
        strength: Fraction of the steps (counted from the end) that use the
            edited embedding; ``1`` uses it for every step.

    Returns:
        A ``PIL.Image`` built from the first decoded sample.
    """
    # The arithmetic creates a new tensor, so the in-place rescaling below
    # does not touch emb0.
    emb = emb0 - emb_minus + emb_add
    emb /= torch.sqrt((emb**2).sum())
    emb *= torch.sqrt((emb0**2).sum())

    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    # Inference only: do not record an autograd graph for the encoder pass.
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
        # The conditioning batches do not change between steps, so build them once.
        plain_context = torch.cat([uncond_embeddings, emb0])
        edited_context = torch.cat([uncond_embeddings, emb])

    # latents = latents * scheduler.init_noise_sigma
    latents = latents.to(torch_device)

    scheduler.set_timesteps(num_inference_steps)

    t0 = int(num_inference_steps * strength)
    first_stage_steps = num_inference_steps - t0

    with autocast("cuda"), torch.no_grad():
        # enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the progress
        # bar knows the number of steps.
        for i, t in enumerate(tqdm(scheduler.timesteps)):
            context = plain_context if i < first_stage_steps else edited_context
            # Duplicate the latents so both guidance branches run in one forward pass.
            latent_model_input = torch.cat([latents] * 2)
            noise_pred = unet(latent_model_input, t,
                              encoder_hidden_states=context).sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents).prev_sample

    with torch.no_grad():
        # Rescale the latents by 1/0.18215 before decoding
        # (presumably the Stable Diffusion VAE latent scaling factor).
        latents = 1 / 0.18215 * latents
        image = vae.decode(latents).sample

    # Map the decoder output from [-1, 1] to [0, 1], move channels last, and
    # quantize to uint8 for PIL.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")

    return Image.fromarray(images[0])


def concept_proj(emb0, emb_z0, emb_z1, emb_z_target,  latents0,
                 strength = 1, num_inference_steps=200, guidance_scale=10):
    """Generate an image while projecting a concept direction out of the noise prediction.

    During the last ``int(num_inference_steps * strength)`` steps (capped at
    ``num_inference_steps``) the noise prediction for ``emb0`` is adjusted:
    with ``u`` the unit-normalized difference ``pred(emb_z1) - pred(emb_z0)``,
    the component of ``pred(emb0) - pred(emb_z_target)`` along ``u`` is
    subtracted from ``pred(emb0)``. Earlier steps use plain classifier-free
    guidance on ``emb0``.

    Args:
        emb0: Embeddings of the base prompt.
        emb_z0: Embeddings of the first prompt defining the concept direction.
        emb_z1: Embeddings of the second prompt defining the concept direction.
        emb_z_target: Embeddings of the reference prompt the base prediction
            is compared against.
        latents0: Initial latents; cloned, so the caller's tensor is not modified.
        strength: Fraction of the steps (counted from the end) that apply the
            projection; ``1`` applies it to every step.
        num_inference_steps: Number of steps passed to ``scheduler.set_timesteps``.
        guidance_scale: Weight applied to the (text - unconditional) noise
            difference.

    Returns:
        A ``PIL.Image`` built from the first decoded sample.
    """
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    # Inference only: do not record an autograd graph for the encoder pass.
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
        # The conditioning batches do not change between steps, so build them once.
        plain_context = torch.cat([uncond_embeddings, emb0])
        proj_context = torch.cat([uncond_embeddings, emb0, emb_z0, emb_z1, emb_z_target])

    latents = latents0.clone()
    # latents = latents * scheduler.init_noise_sigma
    latents = latents.to(torch_device)

    scheduler.set_timesteps(num_inference_steps)

    t0 = min(int(num_inference_steps * strength), num_inference_steps)
    first_stage_steps = num_inference_steps - t0

    with autocast("cuda"), torch.no_grad():
        # enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the progress
        # bar knows the number of steps.
        for i, t in enumerate(tqdm(scheduler.timesteps)):
            if i < first_stage_steps:
                # First stage: ordinary two-branch guidance on the base prompt.
                latent_model_input = torch.cat([latents] * 2)
                noise_pred = unet(latent_model_input, t,
                                  encoder_hidden_states=plain_context).sample
                # Bound to the same name the guidance line below reads; the
                # previous name (noise_pred_text) left noise_pred_text0
                # undefined whenever strength < 1.
                noise_pred_uncond, noise_pred_text0 = noise_pred.chunk(2)
            else:
                # Second stage: one pass over uncond / base / z0 / z1 / target.
                latent_model_input = torch.cat([latents] * 5)
                noise_pred = unet(latent_model_input, t,
                                  encoder_hidden_states=proj_context).sample
                noise_pred_uncond, noise_pred_text0, noise_pred_text_z0, noise_pred_text_z1, noise_pred_text_z_target = noise_pred.chunk(5)

                # Score difference between the base and the target prompt.
                noise_tmp = noise_pred_text0 - noise_pred_text_z_target
                # Unit vector along the Z direction.
                u = noise_pred_text_z1 - noise_pred_text_z0
                u /= torch.sqrt((u**2).sum())
                # Project the Z direction out of the base prediction.
                noise_pred_text0 -= (noise_tmp * u).sum() * u

            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text0 - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents).prev_sample

    with torch.no_grad():
        # Rescale the latents by 1/0.18215 before decoding
        # (presumably the Stable Diffusion VAE latent scaling factor).
        latents = 1 / 0.18215 * latents
        image = vae.decode(latents).sample

    # Map the decoder output from [-1, 1] to [0, 1], move channels last, and
    # quantize to uint8 for PIL.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")

    return Image.fromarray(images[0])

    

def image_grid(imgs, rows, cols):
    """Tile images row by row onto a single RGB canvas.

    The cell size is taken from the first image, and the canvas is
    ``cols`` cells wide and ``rows`` cells tall. Image ``k`` is pasted at
    row ``k // cols``, column ``k % cols``.

    Args:
        imgs: Sequence of PIL images; must contain at least one.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        The assembled ``PIL.Image``.
    """
    cell_w, cell_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h))

    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))

    return canvas


def experiment0(prompt, name, num_inference_steps, guidance_scale,
                N, nrow, ncol, height = 512, width = 512, seeds = None):
    """Generate one image per seed by direct prompting and save them as a grid.

    Args:
        prompt: Text prompt to render.
        name: Output file stem; the grid is written to ``{name}.png``.
        num_inference_steps: Forwarded to ``im_gen``.
        guidance_scale: Forwarded to ``im_gen``.
        N: Number of seeds (``0 .. N-1``) used when ``seeds`` is ``None``.
        nrow: Grid rows.
        ncol: Grid columns.
        height: Image height in pixels; latents use ``height // 8``.
        width: Image width in pixels; latents use ``width // 8``.
        seeds: Optional explicit seeds; overrides ``N``.
    """
    prompt_emb = get_text_embeddings(prompt)
    seed_list = list(range(N)) if seeds is None else seeds
    latent_shape = (1, unet.in_channels, height // 8, width // 8)

    def render(seed):
        # Seeding returns the generator that is then used to draw the initial latents.
        noise = torch.randn(latent_shape, generator=torch.manual_seed(seed))
        return im_gen(prompt_emb, noise,
                      num_inference_steps=num_inference_steps,
                      guidance_scale=guidance_scale)

    pictures = [render(seed) for seed in seed_list]
    image_grid(pictures, rows=nrow, cols=ncol).save(f"{name}.png")
        
def experiment_proj(prompts, name, num_inference_steps, guidance_scale,
                    N, nrow, ncol, height = 512, width = 512, seeds = None):
    """Run ``concept_proj`` for each seed and save the results as a grid.

    Args:
        prompts: Four prompts, in order: base, concept endpoint z0, concept
            endpoint z1, and the reference ("target") prompt.
        name: Output file stem; the grid is written to ``{name}.png``.
        num_inference_steps: Forwarded to ``concept_proj``.
        guidance_scale: Forwarded to ``concept_proj``.
        N: Number of seeds (``0 .. N-1``) used when ``seeds`` is ``None``.
        nrow: Grid rows.
        ncol: Grid columns.
        height: Image height in pixels; latents use ``height // 8``.
        width: Image width in pixels; latents use ``width // 8``.
        seeds: Optional explicit seeds; overrides ``N``.
    """
    base_emb, z0_emb, z1_emb, target_emb = [
        get_text_embeddings(prompts[k]) for k in range(4)
    ]
    seed_list = list(range(N)) if seeds is None else seeds
    latent_shape = (1, unet.in_channels, height // 8, width // 8)

    def render(seed):
        # Seeding returns the generator that is then used to draw the initial latents.
        noise = torch.randn(latent_shape, generator=torch.manual_seed(seed))
        return concept_proj(base_emb, z0_emb, z1_emb, target_emb, noise,
                            num_inference_steps=num_inference_steps,
                            guidance_scale=guidance_scale)

    projected = [render(seed) for seed in seed_list]
    image_grid(projected, rows=nrow, cols=ncol).save(f"{name}.png")
    
def experiment_transfer_score(prompts, name, num_inference_steps, guidance_scale,
                              N, nrow, ncol, s = 1, height = 512, width = 512, seeds = None):
    """Run ``transfer_score`` for each seed and save the results as a grid.

    Args:
        prompts: Three prompts, in order: base, concept to add, concept to remove.
        name: Output file stem; the grid is written to ``{name}.png``.
        num_inference_steps: Forwarded to ``transfer_score``.
        guidance_scale: Forwarded to ``transfer_score``.
        N: Number of seeds (``0 .. N-1``) used when ``seeds`` is ``None``.
        nrow: Grid rows.
        ncol: Grid columns.
        s: Forwarded to ``transfer_score`` as ``strength``.
        height: Image height in pixels; latents use ``height // 8``.
        width: Image width in pixels; latents use ``width // 8``.
        seeds: Optional explicit seeds; overrides ``N``.
    """
    base_emb, plus_emb, minus_emb = [
        get_text_embeddings(prompts[k]) for k in range(3)
    ]
    seed_list = list(range(N)) if seeds is None else seeds
    latent_shape = (1, unet.in_channels, height // 8, width // 8)

    def render(seed):
        # Seeding returns the generator that is then used to draw the initial latents.
        noise = torch.randn(latent_shape, generator=torch.manual_seed(seed))
        return transfer_score(base_emb, plus_emb, minus_emb, noise,
                              num_inference_steps=num_inference_steps,
                              guidance_scale=guidance_scale,
                              strength=s)

    pictures = [render(seed) for seed in seed_list]
    image_grid(pictures, rows=nrow, cols=ncol).save(f"{name}.png")
    
def experiment_transfer_prompt(prompts, name, num_inference_steps, guidance_scale,
                               N, nrow, ncol, s = 1, height = 512, width = 512, seeds = None):
    """Run ``transfer_prompt`` for each seed and save the results as a grid.

    Args:
        prompts: Three prompts, in order: base, concept to add, concept to remove.
        name: Output file stem; the grid is written to ``{name}.png``.
        num_inference_steps: Forwarded to ``transfer_prompt``.
        guidance_scale: Forwarded to ``transfer_prompt``.
        N: Number of seeds (``0 .. N-1``) used when ``seeds`` is ``None``.
        nrow: Grid rows.
        ncol: Grid columns.
        s: Forwarded to ``transfer_prompt`` as ``strength``.
        height: Image height in pixels; latents use ``height // 8``.
        width: Image width in pixels; latents use ``width // 8``.
        seeds: Optional explicit seeds; overrides ``N``.
    """
    base_emb, plus_emb, minus_emb = [
        get_text_embeddings(prompts[k]) for k in range(3)
    ]
    seed_list = list(range(N)) if seeds is None else seeds
    latent_shape = (1, unet.in_channels, height // 8, width // 8)

    def render(seed):
        # Seeding returns the generator that is then used to draw the initial latents.
        noise = torch.randn(latent_shape, generator=torch.manual_seed(seed))
        return transfer_prompt(base_emb, plus_emb, minus_emb, noise,
                               num_inference_steps=num_inference_steps,
                               guidance_scale=guidance_scale,
                               strength=s)

    pictures = [render(seed) for seed in seed_list]
    image_grid(pictures, rows=nrow, cols=ncol).save(f"{name}.png")

# Concept Transfer

## King to Queen

In [None]:
## KING-QUEEN example
# Five seeds (0-4), laid out as a single row.
guidance_scale = 15
num_inference_steps = 50
N, nrow, ncol = 5, 1, 5

# Baseline: direct prompting, saved as king.png.
experiment0("the portrait of a king", 'king', num_inference_steps, guidance_scale, N, nrow, ncol)

# Score-level transfer; prompts are [base, concept to add, concept to remove].
prompts = ["the portrait of a king", "the portrait of a woman", "the portrait of a man"]
experiment_transfer_score(prompts, "king2queen_score", num_inference_steps, guidance_scale, N, nrow, ncol)

# Embedding-level transfer with the same three prompts.
prompts = ["the portrait of a king", "the portrait of a woman", "the portrait of a man"]
experiment_transfer_prompt(prompts, "king2queen_prompt", num_inference_steps, guidance_scale, N, nrow, ncol)

## Frog, cartoon to photorealistic

First, we generate 50 images for each style. From these, we selected a few for the demonstrations in the paper.

In [None]:
# ## FROG, direct prompting, full
# Fifty seeds (0-49), laid out as a 5 x 10 grid.
guidance_scale = 15
num_inference_steps = 50
N, nrow, ncol = 50, 5, 10

# Direct prompting for the three styles; names[k] is the output stem for prompts[k].
names = ['frog_nostyle', 'frog_cartoon', 'frog_photo']
prompts = ['a frog playing the piano, anthropomorphic',
          'a frog playing the piano, anthropomorphic, cartoon',
          'a frog playing the piano, anthropomorphic, photorealistic']

experiment0(prompts[0], names[0], num_inference_steps, guidance_scale, N, nrow, ncol)
experiment0(prompts[1], names[1], num_inference_steps, guidance_scale, N, nrow, ncol)
experiment0(prompts[2], names[2], num_inference_steps, guidance_scale, N, nrow, ncol)

# Score-level transfer; prompts are [base, concept to add, concept to remove].
name = 'frog_TransferScore'
prompts = ['a frog playing the piano, anthropomorphic, cartoon',
          'photorealistic','cartoon']
experiment_transfer_score(prompts, name, num_inference_steps, guidance_scale, N, nrow, ncol)

# Embedding-level transfer with the same bare style words (the "bad" variant).
name = 'frog_TransferPrompt_bad'
experiment_transfer_prompt(prompts, name, num_inference_steps, guidance_scale, N, nrow, ncol)

# Embedding-level transfer with full-sentence add/remove prompts (the "good" variant).
name = 'frog_TransferPrompt_good'
prompts = ['a frog playing the piano, anthropomorphic, cartoon',
           'a man playing the piano, anthropomorphic, photorealistic',
           'a man playing the piano, anthropomorphic, cartoon']
experiment_transfer_prompt(prompts, name, num_inference_steps, guidance_scale, N, nrow, ncol)

In [None]:
# Frog experiments restricted to a hand-picked subset of seeds.
# Settings are stated here so the cell does not depend on values left behind
# by whichever cell happened to run before it.
guidance_scale = 15
num_inference_steps = 50
seeds = [5, 9, 15, 20, 23, 29, 30, 39, 44, 46]

N, nrow, ncol = 10, 1, 10
names = ['frog_nostyle_selected', 'frog_cartoon_selected', 'frog_photo_selected']
# Direct-prompting prompts matching `names`. Without this assignment the calls
# below would pick up a `prompts` list left over from another cell, so the
# saved files would not contain the style their names claim.
prompts = ['a frog playing the piano, anthropomorphic',
          'a frog playing the piano, anthropomorphic, cartoon',
          'a frog playing the piano, anthropomorphic, photorealistic']
experiment0(prompts[0], names[0], num_inference_steps, guidance_scale, N, nrow, ncol, seeds = seeds)
experiment0(prompts[1], names[1], num_inference_steps, guidance_scale, N, nrow, ncol, seeds = seeds)
experiment0(prompts[2], names[2], num_inference_steps, guidance_scale, N, nrow, ncol, seeds = seeds)

# Score-level transfer; prompts are [base, concept to add, concept to remove].
name = 'frog_selected_TransferScore'
prompts = ['a frog playing the piano, anthropomorphic, cartoon',
          'photorealistic','cartoon']
experiment_transfer_score(prompts, name, num_inference_steps, guidance_scale, N, nrow, ncol, seeds = seeds)

# Embedding-level transfer with the same bare style words (the "bad" variant).
name = 'frog_selected_TransferPrompt_bad'
experiment_transfer_prompt(prompts, name, num_inference_steps, guidance_scale, N, nrow, ncol, seeds = seeds)

# Embedding-level transfer with full-sentence add/remove prompts (the "good" variant).
name = 'frog_selected_TransferPrompt_good'
prompts = ['a frog playing the piano, anthropomorphic, cartoon',
           'a man playing the piano, anthropomorphic, photorealistic',
           'a man playing the piano, anthropomorphic, cartoon']
experiment_transfer_prompt(prompts, name, num_inference_steps, guidance_scale, N, nrow, ncol, seeds = seeds)

In [None]:
# Frog experiments on five explicit seeds, laid out as a single row.
guidance_scale = 15
num_inference_steps = 50
seeds = [15, 20, 39, 42, 43]
N, nrow, ncol = 5, 1, 5

# Direct prompting for the three styles; names[k] is the output stem for prompts[k].
prompts = ['a frog playing the piano, anthropomorphic',
          'a frog playing the piano, anthropomorphic, cartoon',
          'a frog playing the piano, anthropomorphic, photorealistic']

names = ['frog_nostyle_5', 'frog_cartoon_5', 'frog_photo_5']
experiment0(prompts[0], names[0], num_inference_steps, guidance_scale, N, nrow, ncol, seeds = seeds)
experiment0(prompts[1], names[1], num_inference_steps, guidance_scale, N, nrow, ncol, seeds = seeds)
experiment0(prompts[2], names[2], num_inference_steps, guidance_scale, N, nrow, ncol, seeds = seeds)

# Score-level transfer; prompts are [base, concept to add, concept to remove].
name = 'frog_5_TransferScore'
prompts = ['a frog playing the piano, anthropomorphic, cartoon',
          'photorealistic','cartoon']
experiment_transfer_score(prompts, name, num_inference_steps, guidance_scale, N, nrow, ncol, seeds = seeds)

# Embedding-level transfer with bare style words (the "bad" variant).
name = 'frog_5_TransferPrompt_bad'
prompts = ['a frog playing the piano, anthropomorphic, cartoon',
           'photorealistic', 'cartoon']
experiment_transfer_prompt(prompts, name, num_inference_steps, guidance_scale, N, nrow, ncol, seeds = seeds)

# Embedding-level transfer with full-sentence add/remove prompts (the "good" variant).
name = 'frog_5_TransferPrompt_good'
prompts = ['a frog playing the piano, anthropomorphic, cartoon',
           'a man playing the piano, anthropomorphic, photorealistic', 
           'a man playing the piano, anthropomorphic, cartoon']
experiment_transfer_prompt(prompts, name, num_inference_steps, guidance_scale, N, nrow, ncol, seeds = seeds)

## Failure example: nurse, deer

In [None]:
## NURSE-DEER example
# Ten seeds (0-9), laid out as a 2 x 5 grid.
guidance_scale = 10
num_inference_steps = 50
N, nrow, ncol = 10, 2, 5

# Baseline: direct prompting.
experiment0("a nurse sitting in a white room", 'nurse_direct_short', num_inference_steps, guidance_scale, N, nrow, ncol)

# Score-level transfer; prompts are [base, concept to add, concept to remove].
prompts = ["a nurse sitting in a white room", "a buck on the grass", "a doe on the grass"]
experiment_transfer_score(prompts, "nurse_deer_failure", num_inference_steps, guidance_scale, N, nrow, ncol)

## analyze reasons of failure
# Render the add/remove prompts on their own for inspection.
names = ['buck', 'doe']
prompts = ["a buck on the grass", "a doe on the grass"]
experiment0(prompts[0], names[0], num_inference_steps, guidance_scale, N, nrow, ncol)
experiment0(prompts[1], names[1], num_inference_steps, guidance_scale, N, nrow, ncol)

## Failure example: dog, renaissance, man

In [None]:
## Dog-Renaissance example
# Ten seeds (0-9), laid out as a 2 x 5 grid.
guidance_scale = 10
num_inference_steps = 50
N, nrow, ncol = 10, 2, 5

# Baseline: direct prompting.
experiment0("a dog sitting on the beach, cartoon", 'dog_cartoon', num_inference_steps, guidance_scale, N, nrow, ncol)

# Score-level transfer; prompts are [base, concept to add, concept to remove].
prompts = ["a dog sitting on the beach, cartoon", "a man, renaissance-style painting", "a man, cartoon"]
experiment_transfer_score(prompts, "dog_cartoon2renaissance_failure", num_inference_steps, guidance_scale, N, nrow, ncol)

## analyze reasons of failure
# Render the remove/add prompts on their own for inspection.
names = ['man_cartoon', 'man_renaissance']
prompts = ["a man, cartoon", "a man, renaissance-style painting"]
experiment0(prompts[0], names[0], num_inference_steps, guidance_scale, N, nrow, ncol)
experiment0(prompts[1], names[1], num_inference_steps, guidance_scale, N, nrow, ncol)

# Concept Projection

## labrador

In [None]:
# Labrador example: twenty seeds (0-19), laid out as a 2 x 10 grid.
guidance_scale = 10
N, nrow, ncol = 20, 2, 10
num_inference_steps = 50
# Baselines: the plain concept and the full prompt, both by direct prompting.
experiment0("a labrador", 'labrador_plain', num_inference_steps, guidance_scale, N, nrow, ncol)

experiment0("a baby labrador on the grass", 'labrador_direct', num_inference_steps, guidance_scale, N, nrow, ncol)
# Concept projection; prompts are [base, z0, z1, target] as consumed by experiment_proj.
prompts = ["a baby labrador on the grass", 
           "a light-colored labrador", "a dark-colored labrador", 
           "a labrador"]
experiment_proj(prompts, "labrador_proj", num_inference_steps, guidance_scale, N, nrow, ncol)

## nurse

For this example, we find that using more steps yields higher-quality images, so we use 500 steps.

In [None]:
## nurse
# Twenty seeds (0-19), laid out as a 2 x 10 grid, with 500 sampling steps.
guidance_scale = 10
N, nrow, ncol = 20, 2, 10
num_inference_steps = 500
# Baseline: direct prompting.
experiment0("a nurse sitting in a white room", 'nurse_direct', num_inference_steps, guidance_scale, N, nrow, ncol)
# Concept projection; prompts are [base, z0, z1, target] as consumed by experiment_proj.
prompts = ["a nurse sitting in a white room", "a woman", "a man", "a person"]
experiment_proj(prompts, "nurse_proj", num_inference_steps, guidance_scale, N, nrow, ncol)

In [None]:
## nurse
# Embedding-level transfer for the nurse prompt: twenty seeds (0-19),
# 2 x 10 grid, 500 sampling steps.
guidance_scale = 10
N, nrow, ncol = 20, 2, 10
num_inference_steps = 500

# Bare-word add/remove prompts (the "bad" variant);
# prompts are [base, concept to add, concept to remove].
name = 'nurse_TransferPrompt_bad'
prompts = ['a nurse sitting in a white room',
           'person', 'woman']
experiment_transfer_prompt(prompts, name, num_inference_steps, guidance_scale, N, nrow, ncol)

# Full-sentence add/remove prompts (the "good" variant).
name = 'nurse_TransferPrompt_good'
prompts = ['a nurse sitting in a white room',
           'a person sitting in a white room', 
           'a woman sitting in a white room']
experiment_transfer_prompt(prompts, name, num_inference_steps, guidance_scale, N, nrow, ncol)

## mathematician

In [None]:
## mathematician
# Twenty seeds (0-19), laid out as a 2 x 10 grid.
guidance_scale = 10
N, nrow, ncol = 20, 2, 10
num_inference_steps = 50
# Baselines: the plain concept and the full prompt, both by direct prompting.
experiment0("a person", 'person_plain', num_inference_steps, guidance_scale, N, nrow, ncol)
experiment0("a portrait of a mathematician", 'mathematician_direct', num_inference_steps, guidance_scale, N, nrow, ncol)

# Concept projection; prompts are [base, z0, z1, target] as consumed by experiment_proj.
prompts = ["a portrait of a mathematician", "a woman", "a man", "a person"]
experiment_proj(prompts, "mathematician_proj", num_inference_steps, guidance_scale, N, nrow, ncol)