# Installs, Imports and Loading the Model

In [None]:
!pip install diffusers==0.6 transformers

In [None]:
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from typing import Callable, List, Optional, Union
import inspect
import numpy as np

In [None]:
# Load the fp16 revision of Stable Diffusion v1.4 in half precision, then move it to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

Downloading:   0%|          | 0.00/543 [00:00<?, ?B/s]

Fetching 16 files:   0%|          | 0/16 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/342 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.63k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/608M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/209 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/209 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/572 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/246M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/525k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/472 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/788 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/772 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.72G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/550 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/167M [00:00<?, ?B/s]

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.


In [None]:
!mkdir imgs

In [None]:
!mkdir imgs/homonym_duplication imgs/meaning_edit imgs/meaning_sum

# Function Definitions

## Vector Utility Functions

In [None]:
def project(a, b):
    """Return the vector projection of ``a`` onto ``b``.

    Both arguments are 1-D tensors accepted by ``torch.dot``. When ``b`` has a
    zero squared length the coefficient is taken to be 0, so the result is
    ``0 * b`` instead of a division by zero.
    """
    b_squared = torch.dot(b, b)
    a_dot_b = torch.dot(a, b)
    scale = 0 if b_squared == 0 else (a_dot_b / b_squared)
    return scale * b

In [None]:
def w_b(w, b):
    """Sum the components of ``w`` along each vector in ``b``.

    Computes ``sum_j dot(w, b[j]) * b[j]``. This equals the orthogonal
    projection of ``w`` onto the span of ``b`` only when the ``b[j]`` are
    orthonormal; that is not checked here.

    Args:
        w: 1-D tensor.
        b: sequence (or 2-D tensor, iterated row by row) of 1-D tensors with
            the same length, dtype and device as ``w``.

    Returns:
        A tensor shaped like ``w``; all zeros when ``b`` is empty.
    """
    # Allocate the accumulator from `w` itself so the function is not tied to
    # a 768-dimensional half-precision CUDA tensor.
    v_b = torch.zeros_like(w)
    for basis_vec in b:
        v_b += torch.dot(w, basis_vec) * basis_vec
    return v_b

def normal(v):
    """Return ``v`` scaled by the reciprocal of its Euclidean length.

    No guard is applied for a zero-length ``v``.
    """
    length = torch.sqrt(torch.dot(v, v))
    inverse_length = 1 / length
    return inverse_length * v

In [None]:
def norm(v):
    """Return the Euclidean (L2) length of the 1-D tensor ``v`` as a 0-dim tensor."""
    squared_length = torch.dot(v, v)
    return torch.sqrt(squared_length)

In [None]:
def cosine_sim(a, b):
    """Return the cosine similarity of the 1-D tensors ``a`` and ``b``.

    The dot product is divided by the product of the two Euclidean lengths;
    no guard is applied for zero-length inputs.
    """
    length_a = torch.sqrt(torch.dot(a, a))
    length_b = torch.sqrt(torch.dot(b, b))
    return torch.dot(a, b) / (length_a * length_b)

## Getting Images

An edited version of `StableDiffusionPipeline.__call__()` that accepts the text embedding directly as input.




In [None]:
def get_images(text_embeddings, pipe, img_name, prompt=None, negative_prompt=None, num_images_per_prompt=3,
               height=512, width=512, num_inference_steps=50, guidance_scale=7.5, eta=0.0,
               generator=None, latents=None):
    """Generate images from a pre-computed text embedding and save them as PNGs.

    This follows the denoising loop of ``StableDiffusionPipeline.__call__`` but
    starts from ``text_embeddings`` instead of a prompt string. The safety
    checker is not run.

    Args:
        text_embeddings: tensor of shape ``(batch, seq_len, dim)`` as returned
            by ``pipe.text_encoder(...)[0]``.
        pipe: the loaded pipeline (tokenizer, text_encoder, unet, vae and
            scheduler are used).
        img_name: path prefix below ``imgs/``; image ``i`` is written to
            ``imgs/<img_name>_<i>.png``.
        prompt: only used to validate ``negative_prompt``; it must have the
            same type as ``negative_prompt`` whenever the latter is given.
        negative_prompt: optional string (or list with one entry per batch
            element) encoded as the unconditional input. Defaults to ``""``.
        num_images_per_prompt: number of images generated per batch element.
        height, width: output size in pixels; both must be divisible by 8.
        num_inference_steps: number of scheduler steps.
        guidance_scale: classifier-free guidance weight; guidance is enabled
            when it is greater than 1.
        eta: forwarded to ``scheduler.step`` only if that scheduler accepts it.
        generator: optional ``torch.Generator`` for reproducible initial noise
            (it must live on the device the noise is sampled on).
        latents: optional pre-sampled initial latents of the expected shape.

    Returns:
        The list of generated PIL images (also written to disk).

    Raises:
        TypeError: ``negative_prompt`` and ``prompt`` have different types.
        ValueError: bad image size, a ``negative_prompt`` list whose length
            differs from the batch size, or ``latents`` of the wrong shape.
    """
    import os

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    with torch.no_grad():
        # Duplicate each prompt embedding once per requested image.
        batch_size, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            if negative_prompt is None:
                uncond_tokens = [""]
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # Pad/truncate the unconditional input to the same length as the prompt embedding.
            uncond_input = pipe.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=seq_len,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0]

            # A single negative prompt is shared by the whole batch; a list already has one row per element.
            uncond_count, uncond_len, _ = uncond_embeddings.shape
            uncond_embeddings = uncond_embeddings.repeat(batch_size // uncond_count, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, uncond_len, -1)

            # One concatenated batch so the UNet runs both branches in a single forward pass.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        latents_shape = (batch_size * num_images_per_prompt, pipe.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if pipe.device.type == "mps":
                # Sample on the CPU first when the pipeline runs on "mps" (as the upstream code does).
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    pipe.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=pipe.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(pipe.device)

        pipe.scheduler.set_timesteps(num_inference_steps)
        timesteps_tensor = pipe.scheduler.timesteps.to(pipe.device)

        # Scale the initial noise by the standard deviation required by the scheduler.
        latents = latents * pipe.scheduler.init_noise_sigma

        # Not every scheduler's step() takes `eta`; pass it only when supported.
        extra_step_kwargs = {}
        if "eta" in inspect.signature(pipe.scheduler.step).parameters:
            extra_step_kwargs["eta"] = eta

        for t in pipe.progress_bar(timesteps_tensor):
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # Undo the 0.18215 latent scaling (the constant used by the upstream pipeline) before decoding.
        latents = 1 / 0.18215 * latents
        decoded = pipe.vae.decode(latents).sample

        # Map from [-1, 1] to [0, 1] and convert to channel-last float arrays for PIL.
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        images = pipe.numpy_to_pil(decoded.cpu().permute(0, 2, 3, 1).float().numpy())

    for i, img in enumerate(images):
        path = f"imgs/{img_name}_{i}.png"
        # Create the target folder on demand so saving does not depend on an earlier mkdir cell.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        img.save(path)

    return images

## Getting Encodings

In [None]:
def one_prompt_embed(prompt_1, pipe):
    """Encode a prompt with the pipeline's tokenizer and text encoder.

    Args:
        prompt_1: prompt string (or list of strings) to encode.
        pipe: pipeline providing ``tokenizer``, ``text_encoder`` and ``device``.

    Returns:
        The first output of ``pipe.text_encoder`` for the prompt, padded to
        ``pipe.tokenizer.model_max_length`` tokens. The tensor is computed
        without gradient tracking.
    """
    text_inputs = pipe.tokenizer(
        prompt_1,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        # Over-long prompts are cut to the encoder's maximum length instead of
        # being fed to the encoder at an unsupported length.
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    # Inference only: avoid building (and keeping alive) an autograd graph.
    with torch.no_grad():
        text_embeddings_1 = pipe.text_encoder(text_input_ids.to(pipe.device))[0]

    return text_embeddings_1

In [None]:
def sum_embedding(prompt_1, prompt_2, pipe, weights=(0.5, 0.5)):
    """Return a weighted sum of the text embeddings of two prompts.

    Args:
        prompt_1: first prompt.
        prompt_2: second prompt.
        pipe: pipeline used to encode both prompts (see ``one_prompt_embed``).
        weights: two-element sequence; ``weights[0]`` scales the embedding of
            ``prompt_1`` and ``weights[1]`` that of ``prompt_2``.

    Returns:
        ``weights[0] * embed(prompt_1) + weights[1] * embed(prompt_2)``.
    """
    # Both prompts go through the shared helper instead of two copies of the
    # tokenizer/encoder boilerplate.
    text_embeddings_1 = one_prompt_embed(prompt_1, pipe)
    text_embeddings_2 = one_prompt_embed(prompt_2, pipe)

    return (weights[0] * text_embeddings_1) + (weights[1] * text_embeddings_2)

## Generate All Images for Experiments on Summing Encodings

In [None]:
def concept_sum(concept_1, concept_2, pipe, filename, weights=(0.5, 0.5), repeat=10):
    """Generate images for a summed embedding and for each concept on its own.

    Args:
        concept_1: first prompt.
        concept_2: second prompt.
        pipe: the loaded pipeline.
        filename: path prefix below ``imgs/``; ``_sum_<i>``, ``_1_<i>`` and
            ``_2_<i>`` are appended for the summed, first and second concept.
        weights: weights forwarded to ``sum_embedding``.
        repeat: number of ``get_images`` calls per embedding.
    """
    # The embeddings do not change between repetitions, so encode each one once.
    summed_embed = sum_embedding(concept_1, concept_2, pipe, weights)
    embed_1 = one_prompt_embed(concept_1, pipe)
    embed_2 = one_prompt_embed(concept_2, pipe)

    for i in range(repeat):
        get_images(summed_embed, pipe, filename + "_sum_" + str(i))
    for i in range(repeat):
        get_images(embed_1, pipe, filename + "_1_" + str(i))
        get_images(embed_2, pipe, filename + "_2_" + str(i))

## Find Meaning Directions

In [None]:
def get_svd(vectors_m, vectors_amb, n, model_dim=768):
    """SVD of the summed within-pair scatter of paired embedding vectors.

    For each pair ``i`` the midpoint of ``vectors_m[i]`` and ``vectors_amb[i]``
    is subtracted from both vectors; half of each deviation's outer product is
    accumulated into a ``(model_dim, model_dim)`` float32 matrix, which is then
    decomposed with ``np.linalg.svd``.

    Args:
        vectors_m: sequence of 1-D tensors of length ``model_dim``.
        vectors_amb: sequence of 1-D tensors paired index-by-index with
            ``vectors_m``.
        n: number of pairs to use.
        model_dim: length of each vector.

    Returns:
        ``(u, s)``: ``u`` is the matrix of left singular vectors as a tensor
        with the dtype and device of the input vectors; ``s`` is the NumPy
        array of singular values in descending order.
    """
    # Follow the inputs rather than assuming half-precision CUDA tensors.
    if n > 0:
        device, dtype = vectors_m[0].device, vectors_m[0].dtype
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dtype = torch.float16

    subspace = torch.zeros((model_dim, model_dim), device=device)

    # Plain accumulation; no autograd graph is needed for the decomposition.
    with torch.no_grad():
        for i in range(n):
            mu = (1/2) * (vectors_m[i] + vectors_amb[i])
            dev_m = vectors_m[i] - mu
            dev_amb = vectors_amb[i] - mu
            subspace += (1/2) * torch.outer(dev_m, dev_m)
            subspace += (1/2) * torch.outer(dev_amb, dev_amb)

    u_m, s_m, _ = np.linalg.svd(subspace.cpu().numpy(), full_matrices=True)
    return torch.from_numpy(u_m).to(dtype).to(device), s_m

def _word_vectors(w, sentences, pipe):
    """Return, per sentence, the encoder output at the position of word ``w``."""
    vectors = []
    for sentence in sentences:
        full_vec = one_prompt_embed(sentence, pipe)
        # +1 skips the leading token the tokenizer adds before the prompt.
        # NOTE(review): assumes every space-separated word maps to exactly one token - confirm.
        w_idx = sentence.split(" ").index(w) + 1
        vectors.append(full_vec[:, w_idx, :].squeeze(0))
    return vectors


def _median_direction(vectors_m, vectors_amb, u, dim, model_dim):
    """Build a unit vector from median coefficients along the first ``dim`` columns of ``u``."""
    n = len(vectors_m)
    direction = torch.zeros(model_dim, dtype=u.dtype, device=u.device)
    for j in range(dim):
        coeffs = sorted(torch.dot(vectors_m[i] - vectors_amb[i], u[:, j]) for i in range(n))
        # Index n // 2 of the sorted values (the upper median for even n).
        direction += coeffs[n // 2] * u[:, j]
    return normal(direction)


def find_vectors(w, sentences_1, sentences_2, sentences_amb, pipe, min_dim=20, threshold=0.9985, model_dim=768):
    """Estimate one unit direction per sense of the ambiguous word ``w``.

    The three sentence lists are paired index-by-index: entry ``i`` of
    ``sentences_1`` / ``sentences_2`` uses ``w`` in sense 1 / sense 2 and entry
    ``i`` of ``sentences_amb`` uses it ambiguously. Only the first
    ``len(sentences_1)`` entries of each list are used. Every sentence must
    contain ``w`` as a space-separated word (``ValueError`` otherwise).

    For each sense, ``get_svd`` is run on the (sense, ambiguous) vector pairs.
    The smallest ``dim >= min_dim`` is chosen such that the first ``dim``
    singular values of both decompositions reach ``threshold`` of their total.
    The direction is the normalised sum, over those ``dim`` singular vectors,
    of the median projection of ``sense - ambiguous`` onto each vector.

    Args:
        w: the ambiguous word.
        sentences_1, sentences_2, sentences_amb: paired sentence lists.
        pipe: pipeline used to encode the sentences.
        min_dim: minimum number of singular vectors to use.
        threshold: required fraction of the singular-value mass.
        model_dim: embedding dimension.

    Returns:
        ``(v_1, v_2)``: unit vectors for sense 1 and sense 2.
    """
    n = len(sentences_1)
    vectors_1 = _word_vectors(w, sentences_1, pipe)
    vectors_2 = _word_vectors(w, sentences_2[:n], pipe)
    vectors_amb = _word_vectors(w, sentences_amb[:n], pipe)

    u_1, s_1 = get_svd(vectors_1, vectors_amb, n, model_dim)
    u_2, s_2 = get_svd(vectors_2, vectors_amb, n, model_dim)

    # The totals are loop-invariant; `dim` is capped so the search always ends
    # and never indexes past the available singular vectors.
    total_1, total_2 = sum(s_1), sum(s_2)
    dim = 0
    while dim < model_dim and (
        dim < min_dim
        or sum(s_1[:dim]) / total_1 < threshold
        or sum(s_2[:dim]) / total_2 < threshold
    ):
        dim += 1

    v_1 = _median_direction(vectors_1, vectors_amb, u_1, dim, model_dim)
    v_2 = _median_direction(vectors_2, vectors_amb, u_2, dim, model_dim)

    return v_1, v_2

## Editing Embedding

In [None]:
def edit_embed(orig_embed, meaning_1, meaning_2):
    """Shift a token embedding towards ``meaning_2``.

    The component of ``orig_embed`` along ``meaning_2`` is removed and replaced
    by ``meaning_2`` scaled by ``|<orig_embed, meaning_1>| + |<orig_embed,
    meaning_2>|``. ``meaning_1`` contributes only through that magnitude.

    Args:
        orig_embed: 1-D embedding of the token to edit.
        meaning_1: 1-D direction of the sense to move away from.
        meaning_2: 1-D direction of the sense to move towards.

    Returns:
        The edited embedding as a new tensor.
    """
    # away from meaning_1, towards meaning_2
    strength_1 = torch.dot(orig_embed, meaning_1).abs()
    strength_2 = torch.dot(orig_embed, meaning_2).abs()
    push = strength_1 + strength_2

    without_target = orig_embed - project(orig_embed, meaning_2)
    return without_target + push * (meaning_2)

## Generate All Images for Sense Editing Experiments

In [None]:
def edit_prompts(word, prompt_dict, sentences_1, sentences_2, sentences_amb, pipe, repeat=5):
    """Generate sense-edited and unedited images for every prompt in ``prompt_dict``.

    Args:
        word: the ambiguous word; it must occur as a space-separated word in
            every prompt and every sentence.
        prompt_dict: maps each prompt to the file-name prefix (below ``imgs/``)
            used for its images.
        sentences_1, sentences_2, sentences_amb: sentence lists passed to
            ``find_vectors`` to estimate the two sense directions.
        pipe: the loaded pipeline.
        repeat: number of ``get_images`` calls per embedding variant.
    """
    v_1, v_2 = find_vectors(word, sentences_1, sentences_2, sentences_amb, pipe)

    def shifted_copy(base_embed, token_pos, away, toward):
        # Work on a detached copy so the unedited embedding stays intact.
        edited = base_embed.detach().clone()
        token_vec = edited[:, token_pos, :].squeeze(0).clone()
        edited[:, token_pos, :] = edit_embed(token_vec, away, toward).clone()
        return edited

    for prompt, filename in prompt_dict.items():
        orig_embed = one_prompt_embed(prompt, pipe)
        # +1 skips the leading token the tokenizer adds before the prompt.
        idx = prompt.split(" ").index(word) + 1

        embed_1 = shifted_copy(orig_embed, idx, v_2, v_1)
        embed_2 = shifted_copy(orig_embed, idx, v_1, v_2)

        for run in range(repeat):
            tag = str(run)
            get_images(embed_1, pipe, filename + "sense_1_" + tag)
            get_images(embed_2, pipe, filename + "sense_2_" + tag)
            get_images(orig_embed, pipe, filename + "amb_" + tag)

# Experiments

## Homonym Duplication

Note: Homonym duplication is rare in Stable Diffusion, so it may not occur in any of the generated images.

In [None]:
# Encode the prompt once; every repetition reuses the same embedding.
dup_bow_embed = one_prompt_embed("a woman with a silk bow and arrow", pipe)
for i in range(5):
    get_images(dup_bow_embed, pipe, "homonym_duplication/dup_bow_" + str(i))

In [None]:
# Encode the prompt once; every repetition reuses the same embedding.
dup_crane_embed = one_prompt_embed("tall cranes by the ocean", pipe)
for i in range(5):
    get_images(dup_crane_embed, pipe, "homonym_duplication/dup_crane_" + str(i))

In [None]:
# Encode the prompt once; every repetition reuses the same embedding.
dup_crane_sea_embed = one_prompt_embed("a crane by the ocean", pipe)
for i in range(5):
    get_images(dup_crane_sea_embed, pipe, "homonym_duplication/dup_crane_sea_" + str(i))

In [None]:
# The prompt text is needed twice (embedding and negative-prompt type check), so name it once.
neg_dup_bat_prompt = "a bat and a baseball fly through the air"
neg_dup_bat_negative = "disfigured, deformed, bad anatomy, low quality, jpeg artifacts"
neg_dup_bat_embed = one_prompt_embed(neg_dup_bat_prompt, pipe)
for i in range(10):
    get_images(
        neg_dup_bat_embed,
        pipe,
        "homonym_duplication/neg_dup_bat_" + str(i),
        prompt=neg_dup_bat_prompt,
        negative_prompt=neg_dup_bat_negative,
    )

In [None]:
# Encode the prompt once; every repetition reuses the same embedding.
dup_glasses_embed = one_prompt_embed("a man with glasses", pipe)
for i in range(5):
    get_images(dup_glasses_embed, pipe, "homonym_duplication/dup_glasses_" + str(i))

In [None]:
# Encode the prompt once; every repetition reuses the same embedding.
dup_bow_gent_embed = one_prompt_embed("a gentleman with a bow and arrow", pipe)
for i in range(5):
    get_images(dup_bow_gent_embed, pipe, "homonym_duplication/dup_bow_gent_" + str(i))

In [None]:
# Encode the prompt once; every repetition reuses the same embedding.
dup_bat_cave_embed = one_prompt_embed("a baseball bat inside a spooky cave", pipe)
for i in range(5):
    get_images(dup_bat_cave_embed, pipe, "homonym_duplication/dup_bat_cave_" + str(i))

## Summing Encodings

In [None]:
# (first concept, second concept, output prefix below imgs/) for each summed-encoding experiment.
sum_experiments = [
    ("tree", "cat", "meaning_sum/treecat"),
    ("dog", "lake", "meaning_sum/doglake"),
    ("bear", "waterfall", "meaning_sum/bearwaterfall"),
    ("bear", "hat", "meaning_sum/bearhat"),
    ("a wall painted red", "a wall painted blue", "meaning_sum/redbluewall"),
    ("a completely black cat", "a completely white cat", "meaning_sum/blackwhitecat"),
]
for first_concept, second_concept, out_prefix in sum_experiments:
    concept_sum(first_concept, second_concept, pipe, out_prefix)

## Editing Meaning

In [None]:
# Sentence sets for estimating sense directions. Within each word's group the
# three lists are paired index-by-index: entry i uses the word in one sense, in
# the other sense, and ambiguously. Every sentence contains the target word as
# a space-separated token.

# "crane" as the bird.
crane_sentence_animal = [
    "a flying crane",
    "there is a flying crane",
    "there is a hungry crane on the nature reserve",
    "a hungry crane hunts fish",
    "a boy feeds a hungry crane",
    "a feathered crane beside a nest",
    "a hungry crane is eating some fish",
    "a feathered crane in a nest",
]

# "crane" as the construction machine.
crane_sentence_construction = [
    "a tower crane",
    "there is a tower crane",
    "there is a tower crane on the building site",
    "a tower crane lifts loads",
    "a man operates a tower crane",
    "a tower crane beside a bulldozer",
    "a tower crane is lifting a container",
    "a tower crane in a quarry",
]

# "crane" without disambiguating context.
crane_sentence_amb = [
    "a crane",
    "there is a crane",
    "there is a crane on the other side",
    "a crane is tall",
    "a boy sees a crane",
    "a crane beside a tree",
    "a crane is casting a shadow",
    "a crane by the ocean",
]

# "bat" as the piece of baseball equipment.
bat_sentence_baseball = [
    "a baseball bat",
    "there is a baseball bat",
    "i play baseball with the bat",
    "to play baseball you need a bat and a ball",
    "the boy bought a baseball bat",
    "a baseball player throws a baseball bat",
    "a baseball bat is laying on the base",
    "a baseball bat in the store",
    "a sports store sells a baseball bat",
]

# "bat" as the animal.
bat_sentence_animal = [
    "a vampire bat",
    "there is a vampire bat",
    "i feed insects to the vampire bat",
    "to celebrate halloween you need a vampire bat and a pumpkin",
    "the boy saw a vampire bat",
    "a wildlife expert feeds a vampire bat",
    "a vampire bat is hanging from the tree",
    "a vampire bat in the cave",
    "a special zoo keeps a vampire bat",
]

# "bat" without disambiguating context.
bat_sentence_amb = [
    "a bat",
    "there is a bat",
    "i do things with the bat",
    "to do anything you need a bat and something else",
    "the person saw a bat",
    "a person mentions a bat",
    "a bat is rolling on the floor",
    "a bat in the place",
    "a location has a bat",
]

In [None]:
# Prompt -> output prefix (below imgs/) for the "bat" sense-editing run.
bat_prompts = {
    "a bat": "meaning_edit/bat_",
    "a bat and a baseball fly through the air": "meaning_edit/bat_fly_through_the_air_",
}
# Sense 1 = baseball bat, sense 2 = animal.
edit_prompts("bat", bat_prompts, bat_sentence_baseball, bat_sentence_animal, bat_sentence_amb, pipe)

# Prompt -> output prefix (below imgs/) for the "crane" sense-editing run.
crane_prompts = {
    "a crane": "meaning_edit/crane_",
    "a crane by the ocean": "meaning_edit/crane_by_ocean_",
    "a crane surrounded by nature": "meaning_edit/crane_nature_",
}
# Sense 1 = construction crane, sense 2 = bird.
edit_prompts("crane", crane_prompts, crane_sentence_construction, crane_sentence_animal, crane_sentence_amb, pipe)

# Zip Images to Download

In [None]:
!zip -r imgs.zip imgs/ 

  adding: imgs/ (stored 0%)
  adding: imgs/homonym_duplication/ (stored 0%)
  adding: imgs/meaning_sum/ (stored 0%)
  adding: imgs/meaning_edit/ (stored 0%)
  adding: imgs/meaning_edit/bat_ball_sense_2_2_1.png (deflated 0%)
  adding: imgs/meaning_edit/bat_ball_amb_4_2.png (deflated 0%)
  adding: imgs/meaning_edit/bat_ball_sense_1_3_0.png (deflated 0%)
  adding: imgs/meaning_edit/bat_ball_amb_4_0.png (deflated 0%)
  adding: imgs/meaning_edit/bat_ball_sense_2_2_2.png (deflated 0%)
  adding: imgs/meaning_edit/bat_ball_sense_1_4_2.png (deflated 0%)
  adding: imgs/meaning_edit/bat_ball_sense_1_2_0.png (deflated 0%)
  adding: imgs/meaning_edit/bat_ball_sense_1_4_1.png (deflated 0%)
  adding: imgs/meaning_edit/bat_ball_sense_1_0_0.png (deflated 0%)
  adding: imgs/meaning_edit/bat_ball_sense_2_4_0.png (deflated 0%)
  adding: imgs/meaning_edit/bat_ball_sense_1_2_2.png (deflated 0%)
  adding: imgs/meaning_edit/bat_ball_sense_1_1_0.png (deflated 0%)
  adding: imgs/meaning_edit/bat_ball_amb_3_0.pn