In [None]:
from google.colab import drive
drive.mount('/content/drive')

import sys
sys.path.append('/content/drive/MyDrive/cs229/sd_data/')
%cd drive/MyDrive/cs229/sd_data

In [None]:
# Install the Hugging Face libraries (diffusers pinned to 0.8.0) plus ftfy.
!pip install -qq diffusers==0.8.0 transformers ftfy
# ipywidgets is constrained to the 7.x series.
!pip install -qq "ipywidgets>=7,<8"

In [None]:
import pickle
# SECURITY: pickle.load can execute arbitrary code - only open a pickle you produced yourself.
# Later cells read this object as all_added_emb[word][token] and call .detach().numpy()
# on the stored value.
with open("all_added_emb.pickle", "rb") as f:
    all_added_emb = pickle.load(f)

In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.colors import ListedColormap
from PIL import Image
from sklearn import preprocessing
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from tqdm import tqdm

In [None]:
from huggingface_hub import HfApi
import requests

api = HfApi()

In [None]:
import torch
from diffusers import StableDiffusionPipeline

from huggingface_hub import notebook_login

notebook_login()

In [None]:
def get_2d_tsne(vecs, num_pca_comp = 75, perplexity = 40, random_state = None):
    """Reduce vectors to 2-D with PCA followed by t-SNE and rescale each axis to [0, 1].

    Args:
        vecs: 2-D array-like with one row per vector.
        num_pca_comp: number of PCA components kept before t-SNE. An integer value is
            clamped to ``min(n_samples, n_features)``, the most PCA can extract.
        perplexity: t-SNE perplexity.
        random_state: optional seed forwarded to both PCA and TSNE so a run can be
            reproduced; ``None`` keeps the previous unseeded behaviour.

    Returns:
        ``(tx, ty)``: two 1-D arrays holding the min-max scaled t-SNE coordinates.
    """
    vecs = np.asarray(vecs)
    if isinstance(num_pca_comp, (int, np.integer)):
        # Asking PCA for more components than the data supports raises a ValueError.
        num_pca_comp = min(num_pca_comp, *vecs.shape)

    reduced = PCA(n_components=num_pca_comp, random_state=random_state).fit_transform(vecs)
    embedded = TSNE(
        n_components=2, perplexity=perplexity, random_state=random_state
    ).fit_transform(reduced)

    def _to_unit_range(values):
        # Min-max scale; a constant axis maps to zeros instead of dividing by zero.
        low, high = np.min(values), np.max(values)
        if high == low:
            return np.zeros_like(values)
        return (values - low) / (high - low)

    return _to_unit_range(embedded[:, 0]), _to_unit_range(embedded[:, 1])

### Visual tsne

In [None]:
# Keep only the concepts whose first example image can be decoded, recording for each
# its name, its thumbnail path and the first embedding stored for it.
words, vecs, imgs = [], [], []
library_root = os.getcwd()

for concept, embeds in all_added_emb.items():
    thumb_path = os.path.join(library_root, f"sd-concepts-library/{concept}/0.jpeg")
    if cv2.imread(thumb_path) is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        continue
    imgs.append(thumb_path)
    words.append(concept)
    first_token = list(embeds.keys())[0]
    vecs.append(embeds[first_token].detach().numpy())

vecs = np.array(vecs)

In [None]:
%%capture
# Reduce the concept embeddings to 300 PCA components, then embed them in 2-D with t-SNE.
pca = PCA(n_components=300).fit_transform(vecs)
tsne = TSNE(n_components=2, perplexity=30, angle=0.2).fit_transform(pca)
tx, ty = tsne[:,0], tsne[:,1]
# Min-max scale each axis to [0, 1]; the collage cell multiplies these by the canvas size.
tx = (tx-np.min(tx)) / (np.max(tx) - np.min(tx))
ty = (ty-np.min(ty)) / (np.max(ty) - np.min(ty))

In [None]:
# Paste each concept thumbnail onto one large canvas at its scaled t-SNE position.
width = 4000
height = 3000
max_dim = 100  # longest side of a pasted thumbnail, in pixels

full_image = Image.new('RGBA', (width, height))
for img, x, y in zip(imgs, tx, ty):
    # Context manager closes the file once the resized copy has been produced.
    with Image.open(img) as source:
        # Shrink factor that fits the image inside max_dim x max_dim; never upscale.
        rs = max(1, source.width/max_dim, source.height/max_dim)
        # Image.LANCZOS is the filter formerly aliased as Image.ANTIALIAS, an alias
        # that was removed in Pillow 10.
        tile = source.resize((int(source.width/rs), int(source.height/rs)), Image.LANCZOS)
    full_image.paste(tile, (int((width-max_dim)*x), int((height-max_dim)*y)), mask=tile.convert('RGBA'))

plt.figure(figsize = (16,12))
plt.imshow(full_image)
full_image.save("tSNE_full_image.png")

In [None]:
import json
# Persist, for every plotted concept, its absolute image path, its scaled t-SNE
# coordinates (as plain floats so they serialise) and its name.
tsne_data = [{"path":os.path.abspath(img), "x":float(x), "y": float(y), "word":word}
             for img, x, y, word in zip(imgs, tx, ty, words)]
with open("tSNE_data.json", 'w') as f:
    json.dump(tsne_data, f)

In [None]:
# Install RasterFairy from a fresh clone of its repository, then return to the data
# directory. The next cell imports the package.
!git clone https://github.com/Quasimondo/RasterFairy.git
%cd RasterFairy/
!pip install .
%cd ..

In [None]:
import rasterfairy  # imported here because the cell above installs it

# Snap the 2-D t-SNE point cloud onto a regular grid. The call returns the grid
# coordinates together with the grid dimensions; size the canvas from those
# dimensions instead of a hard-coded 26 x 26, which only fits one point count.
grid_assignment = rasterfairy.transformPointCloud2D(tsne)
grid_xy, grid_shape = grid_assignment
nx, ny = int(grid_shape[0]), int(grid_shape[1])

tile_width = 50
tile_height = 50

full_width = tile_width * nx
full_height = tile_height * ny
aspect_ratio = float(tile_width) / tile_height

grid_image = Image.new('RGBA', (full_width, full_height))

for img, grid_pos in zip(imgs, grid_xy):
    idx_x, idx_y = grid_pos
    x, y = tile_width * idx_x, tile_height * idx_y
    # Context manager closes the file once the cropped/resized copy exists.
    with Image.open(img) as source:
        tile_ar = float(source.width) / source.height
        # Centre-crop to the tile aspect ratio before resizing so nothing is stretched.
        if tile_ar > aspect_ratio:
            margin = 0.5 * (source.width - aspect_ratio * source.height)
            box = (margin, 0, margin + aspect_ratio * source.height, source.height)
        else:
            margin = 0.5 * (source.height - float(source.width) / aspect_ratio)
            box = (0, margin, source.width, margin + float(source.width) / aspect_ratio)
        # Image.LANCZOS is the filter formerly aliased as Image.ANTIALIAS (removed in Pillow 10).
        tile = source.crop(box).resize((tile_width, tile_height), Image.LANCZOS)
    grid_image.paste(tile, (int(x), int(y)))

plt.figure(figsize = (16,12))
plt.imshow(grid_image)

## Retrieve object type info -> plot by style vs. object

In [None]:
def get_concept_type(model_obj, timeout=30):
  """Fetch the ``type_of_concept.txt`` file of a Hub model and return its text.

  Args:
    model_obj: object exposing a ``modelId`` attribute.
    timeout: seconds to wait for the HTTP request; without one ``requests`` can
      block indefinitely on a stalled connection.

  Returns:
    The response body with surrounding whitespace removed. The HTTP status is not
    checked, so for a repository without this file the body is whatever the server sent.
  """
  concept_type = f"https://huggingface.co/{model_obj.modelId}/raw/main/type_of_concept.txt"
  response = requests.get(concept_type, timeout=timeout)
  # Strip so a trailing newline cannot break the exact comparisons with 'style' / 'object'.
  return response.text.strip()

# Materialise the listing: depending on the huggingface_hub release, list_models may
# return a lazy iterator, which supports neither indexing nor a second pass.
models_list = list(api.list_models(author="sd-concepts-library"))
get_concept_type(models_list[0])

In [None]:
# One HTTP request per model; filled incrementally so partial results survive an interrupt.
model_id_to_type = {}
for model_info in tqdm(models_list):
  model_id_to_type[model_info.modelId] = get_concept_type(model_info)

In [None]:
# Re-key by the bare concept name, i.e. the part of the model id after the last '/'.
word_to_type = {
  model_id.split('/')[-1]: concept_type
  for model_id, concept_type in model_id_to_type.items()
}

In [None]:
# Collect name and first stored embedding for every concept that has a known type.
words, vecs = [], []

for concept, embeds in all_added_emb.items():
    if concept not in word_to_type:
        print("Warning: skipping", concept)
        continue
    words.append(concept)
    first_token = list(embeds.keys())[0]
    vecs.append(embeds[first_token].detach().numpy())

vecs = np.array(vecs)

In [None]:
# Integer class per concept for the scatter colouring: style -> 0, object -> 1.
color_map = {'style': 0, 'object': 1}
colors = [color_map[word_to_type[word]] for word in words]

In [None]:
# Standardise every embedding dimension to zero mean and unit variance.
scaler = preprocessing.StandardScaler()
vecs_normed = scaler.fit_transform(vecs)

In [None]:
# t-SNE of the standardised concept embeddings, coloured by concept type.
normed_tx, normed_ty = get_2d_tsne(vecs_normed, num_pca_comp = 75, perplexity = 14)

scatter = plt.scatter(normed_tx, normed_ty, c=colors, cmap=ListedColormap(['r','g']), s=1.5)

plt.legend(handles=scatter.legend_elements()[0], labels=['style', 'object'])
plt.title('t-SNE of concepts by type')
# Saved under its own name: the following cell writes "concept_type_tsne.png" for the
# un-standardised embeddings and used to overwrite this figure.
plt.savefig("concept_type_tsne_normed.png", dpi=300)
plt.show()

In [None]:
# Same plot on the raw (un-standardised) embeddings, using get_2d_tsne's default settings.
# Note that this rebinds tx / ty, which earlier cells used for the image collage.
tx, ty = get_2d_tsne(vecs)
scatter = plt.scatter(tx, ty, c=colors, cmap=ListedColormap(['r','b']), s=1.5)

plt.legend(handles=scatter.legend_elements()[0], labels=['style', 'object'])
plt.title(f't-SNE visualization of textual inversion concepts by type')
# plt.title(f't-SNE visualization with perplexity = {perplexity}, num_pca_comp = {num_pca_comp}')
plt.savefig("concept_type_tsne.png", dpi=300)
plt.show()

## Utilities

In [None]:
def load_words_and_vecs():
    """Return every concept name and its first stored embedding.

    Reads the notebook-level ``all_added_emb`` mapping.

    Returns:
        ``(words, vecs)``: the list of concept names in mapping order and a NumPy
        array built from the matching embeddings, one per concept.
    """
    names = list(all_added_emb)

    rows = []
    for embeds in all_added_emb.values():
        first_token = list(embeds.keys())[0]
        rows.append(embeds[first_token].detach().numpy())

    return names, np.array(rows)

## tSNE with normal word embeddings

In [None]:
!pip install accelerate

In [None]:
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"

# Load the "fp16" revision with float16 weights, then move the pipeline to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, revision="fp16")
pipe = pipe.to(device)

In [None]:
words, vecs = load_words_and_vecs()

In [None]:
# Copy the text encoder's input-embedding matrix to the CPU and keep every 100th row
# (a 1% sample; the original note puts the full vocabulary at about 50k tokens).
model_token_vecs = pipe.text_encoder.get_input_embeddings().weight.data.cpu()
print(model_token_vecs.shape)
selected_model_token_vecs = model_token_vecs
# Row i of the sample is row i * 100 of the full embedding matrix.
selected_model_token_vecs = selected_model_token_vecs[::100]
print(selected_model_token_vecs.shape)

In [None]:
def plot_words_vs_concepts(all_vecs, colors_to_use, color_labels, use_tsne=True, save_path=None):
    """Scatter-plot a 2-D projection of embeddings coloured by an integer class code.

    Args:
        all_vecs: 2-D array with one embedding per row.
        colors_to_use: one numeric class code per row of ``all_vecs``.
        color_labels: legend labels, matched in order to the sorted distinct class codes.
        use_tsne: project with ``get_2d_tsne`` when True, otherwise with 2-component PCA.
        save_path: where to save the figure. Defaults to "word_vs_concept_tsne.png" for
            t-SNE and "word_vs_concept_pca.png" for PCA, so that calling the function
            once per projection no longer overwrites the first image with the second.

    Returns:
        ``(x, y)``: the plotted 2-D coordinates.
    """
    if use_tsne:
        token_vs_concept_tx, token_vs_concept_ty = get_2d_tsne(all_vecs)
        method_name = 't-SNE'
    else:
        pca = PCA(n_components=2).fit_transform(all_vecs)
        token_vs_concept_tx = pca[:, 0]
        token_vs_concept_ty = pca[:, 1]
        method_name = 'PCA'

    if save_path is None:
        save_path = "word_vs_concept_tsne.png" if use_tsne else "word_vs_concept_pca.png"

    # Up to three classes keep the original red/blue/green map. With more labels the
    # palette grows, because a 3-colour map would give two classes the same colour.
    palette = ['r', 'b', 'g']
    extra_needed = len(color_labels) - len(palette)
    if extra_needed > 0:
        palette = palette + ['m', 'c', 'y', 'k'][:extra_needed]

    scatter = plt.scatter(
        token_vs_concept_tx,
        token_vs_concept_ty,
        c=colors_to_use,
        cmap=ListedColormap(palette), s=1.5)

    plt.legend(handles=scatter.legend_elements()[0], labels=color_labels)
    plt.title(f'{method_name} of concepts vs. existing tokens')

    if not use_tsne:
        # Fixed window for the PCA view; points outside it are not shown.
        plt.xlim((-2, 1.25))
        plt.ylim((-2, 2))

    plt.savefig(save_path, dpi=300)
    plt.show()

    return token_vs_concept_tx, token_vs_concept_ty

In [None]:
# Stack the concept embeddings (first) on top of the sampled model token embeddings.
all_vecs = np.concatenate([vecs, selected_model_token_vecs])
print(all_vecs.shape)

# Class code per row: 1 for the concept rows, 0 for the model-token rows.
new_colors = np.concatenate((np.ones(vecs.shape[0]), np.zeros(selected_model_token_vecs.shape[0])))
new_color_labels = ['existing tokens', 'added concepts']

# Plot the same data twice: once with t-SNE, once with PCA.
tx, ty = plot_words_vs_concepts(all_vecs, new_colors, new_color_labels)
px, py = plot_words_vs_concepts(all_vecs, new_colors, new_color_labels, use_tsne=False)

In [None]:
# Some of the sampled token vectors have zero magnitude; keep only the rows whose
# L2 norm is non-zero.
nonzero_model_token_vecs = selected_model_token_vecs[
    np.where(np.linalg.norm(selected_model_token_vecs, axis=-1) != 0)
]

In [None]:
# # make them unit vectors
# all_vecs_normed = np.concatenate([vecs, nonzero_model_token_vecs])
# all_vecs_normed = all_vecs_normed / np.linalg.norm(all_vecs_normed, axis=-1)[:, None]

# new_colors = np.concatenate((np.ones(vecs.shape[0]), np.zeros(nonzero_model_token_vecs.shape[0])))
# new_color_labels = ['existing tokens', 'added concepts']

# normed_tx, normed_ty = plot_words_vs_concepts(all_vecs_normed, new_colors, new_color_labels)
# normed_px, normed_py = plot_words_vs_concepts(all_vecs_normed, new_colors, new_color_labels, use_tsne=False)

In [None]:
# # normalize so that each feature has 0 mean and unit variance
# # normalize words and concepts separately first, then together 
# vecs_normed = (vecs - np.mean(vecs, axis=0)) / np.std(vecs, axis=0)
# nonzero_model_token_vecs_normed = (nonzero_model_token_vecs - torch.mean(nonzero_model_token_vecs, axis=0)) / torch.std(nonzero_model_token_vecs, axis=0)

# all_vecs_normed = np.concatenate([vecs_normed, nonzero_model_token_vecs_normed])

# # all_vecs_normed = np.concatenate([vecs, nonzero_model_token_vecs])
# all_vecs_normed = (all_vecs_normed - np.mean(all_vecs_normed, axis=0)) / np.std(all_vecs_normed, axis=0)

# new_colors = np.concatenate((np.ones(vecs.shape[0]), np.zeros(nonzero_model_token_vecs.shape[0])))
# new_color_labels = ['existing tokens', 'added concepts']

# normed_tx, normed_ty = plot_words_vs_concepts(all_vecs_normed, new_colors, new_color_labels)
# normed_px, normed_py = plot_words_vs_concepts(all_vecs_normed, new_colors, new_color_labels, use_tsne=False)

In [None]:
# Standardise concepts and non-zero tokens together, by hand: subtract the per-feature
# mean and divide by the per-feature standard deviation.
all_vecs_normed = np.concatenate([vecs, nonzero_model_token_vecs])
all_vecs_normed = (all_vecs_normed - np.mean(all_vecs_normed, axis=0)) / np.std(all_vecs_normed, axis=0)

# Class code per row: 1 for the concept rows, 0 for the model-token rows.
new_colors = np.concatenate((np.ones(vecs.shape[0]), np.zeros(nonzero_model_token_vecs.shape[0])))
new_color_labels = ['existing tokens', 'added concepts']

normed_tx, normed_ty = plot_words_vs_concepts(all_vecs_normed, new_colors, new_color_labels)
normed_px, normed_py = plot_words_vs_concepts(all_vecs_normed, new_colors, new_color_labels, use_tsne=False)

In [None]:
# Same standardisation as the previous cell, done with scikit-learn's StandardScaler.
all_vecs = np.concatenate([vecs, nonzero_model_token_vecs])

from sklearn import preprocessing
scaler = preprocessing.StandardScaler().fit(all_vecs)
all_vecs_normed = scaler.transform(all_vecs)

# Class code per row: 1 for the concept rows, 0 for the model-token rows.
new_colors = np.concatenate((np.ones(vecs.shape[0]), np.zeros(nonzero_model_token_vecs.shape[0])))
new_color_labels = ['existing tokens', 'added concepts']

normed_tx, normed_ty = plot_words_vs_concepts(all_vecs_normed, new_colors, new_color_labels)
normed_px, normed_py = plot_words_vs_concepts(all_vecs_normed, new_colors, new_color_labels, use_tsne=False)

In [None]:
# Class code per concept: 1 = object, 2 = any other known type, 3 = type unknown.
colors = []
for concept in words:
  if concept in word_to_type:
    colors.append(1 if word_to_type[concept] == 'object' else 2)
    continue
  print('missing')
  colors.append(3)
len(colors)

In [None]:
# Standardise concepts and non-zero tokens together (zero mean, unit variance per feature).
all_vecs = np.concatenate([vecs, nonzero_model_token_vecs])

scaler = preprocessing.StandardScaler().fit(all_vecs)
all_vecs_normed = scaler.transform(all_vecs)

# Class code per row: the per-concept codes from `colors` (1 object, 2 style, 3 unknown)
# for the concept rows, followed by 0 for every model-token row.
new_colors = np.concatenate((np.array(colors), np.zeros(nonzero_model_token_vecs.shape[0])))
new_color_labels = ['existing tokens', 'added concepts (object)', 'added concepts (style)', 'added concepts (unknown)']

normed_tx, normed_ty = plot_words_vs_concepts(all_vecs_normed, new_colors, new_color_labels)
normed_px, normed_py = plot_words_vs_concepts(all_vecs_normed, new_colors, new_color_labels, use_tsne=False)

In [None]:
# concept_vec_norms = np.linalg.norm(vecs, axis=-1)
# word_vec_norms = np.linalg.norm(selected_model_token_vecs, axis=-1)
# np.median(word_vec_norms)

#### K-means to find concept vectors similar to words (abandoned)

In [None]:
num_samples = all_vecs_normed.shape[0]
kmeans_input_data = np.zeros((num_samples, 2))
kmeans_input_data[:, 0] = normed_tx
kmeans_input_data[:, 1] = normed_ty 

In [None]:
from sklearn.cluster import KMeans

kmeans = KMeans(init="k-means++", n_clusters=2, n_init=4)
kmeans.fit(kmeans_input_data)

In [None]:
cluster_pred = kmeans.predict(kmeans_input_data)

In [None]:
scatter = plt.scatter(
    normed_tx, 
    normed_ty, 
    c=cluster_pred, 
    cmap=ListedColormap(['r','b', 'g']), s=1.5)

plt.legend(handles=scatter.legend_elements()[0], labels=['cluster 1', 'cluster 2'])
plt.title(f't-SNE visualization of concepts and words (K-means)')
plt.show()

In [None]:
len(np.where(new_colors == cluster_pred)[0])

In [None]:
len(np.where(new_colors != cluster_pred)[0])

In [None]:
# Indices of points whose class code differs from their K-means cluster.
idces = np.where(new_colors != cluster_pred)[0]

# Map a row of nonzero_model_token_vecs back to its tokenizer id. The vocabulary was
# sampled with a stride of 100 and zero-norm rows were then dropped, so both steps
# have to be undone (the former `real_idx * 5` accounted for neither).
token_stride = 100
kept_sample_rows = np.where(np.linalg.norm(selected_model_token_vecs, axis=-1) != 0)[0]

for idx in idces:
    if idx < vecs.shape[0]: # added concept that clusters with words
        print('concept', idx, words[idx])
    else: # word that clusters with added concepts
        print('word', idx)
        real_idx = idx - vecs.shape[0]
        token_id = int(kept_sample_rows[real_idx]) * token_stride
        print(pipe.tokenizer.convert_ids_to_tokens([token_id]))

## Embedding arithmetic

In [None]:
# Generate two images for the prompt and display them.
# NOTE(review): no earlier cell registers "<anime-girl>" with pipe.tokenizer, so the
# prompt may be tokenised as ordinary text here - confirm this is intended.
images = pipe("A photo of <anime-girl>", num_images_per_prompt=2, num_inference_steps=50, guidance_scale=7.5)["images"]
for img in images:
    plt.imshow(img)
    plt.show()
# 421 is the index used as `concept_index` further down.
print(words[421])


In [None]:
# (feminine, masculine) word pairs used to estimate a gender direction.
word_pairs = [
    ('she', 'he'),
    ('daughter', 'son'),
    ('woman', 'man'),
    ('actress', 'actor'),
    ('mother', 'father')
]

# Flatten the pairs into one list: [s0, h0, s1, h1, ...].
words_of_interest = []
for s, h in word_pairs:
    words_of_interest.append(s)
    words_of_interest.append(h)

# NOTE(review): convert_tokens_to_ids looks the strings up in the tokenizer vocabulary;
# confirm these bare words are the intended vocabulary entries.
token_ids = pipe.tokenizer.convert_tokens_to_ids(words_of_interest)
embs = pipe.text_encoder.get_input_embeddings().weight.data[token_ids]

raw_emb_dict = {}
emb_dict = {} # unit length
for i, word in enumerate(words_of_interest):
    raw_emb_dict[word] = embs[i].cpu()
    emb_dict[word] = raw_emb_dict[word] / np.linalg.norm(raw_emb_dict[word])


# One difference vector (feminine minus masculine, on unit-length embeddings) per pair.
word_pair_diffs = []
for s, h in word_pairs:
    word_pair_diffs.append(emb_dict[s] - emb_dict[h])

In [None]:
# Tensor.numpy() returns an array that shares memory with the tensor. Copy first so
# the in-place normalisation below cannot silently overwrite the "raw" embeddings
# kept in raw_emb_dict, which later cells read again.
she_vec = raw_emb_dict['woman'].numpy().copy()
he_vec = raw_emb_dict['man'].numpy().copy()

she_vec /= np.linalg.norm(she_vec)
he_vec /= np.linalg.norm(he_vec)

# Unit-length direction pointing from 'woman' towards 'man'.
vec_mod = he_vec - she_vec
vec_mod /= np.linalg.norm(vec_mod)

In [None]:
vec_mod.shape

In [None]:
# compare to learned delta
cur_delta = np.loadtxt('gender_data/gender_delta.txt')
cur_delta.shape

In [None]:
def cos_sim(a, b):
    """Return the cosine similarity of two vectors: dot(a, b) / (|a| * |b|)."""
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / norm_product
cos_sim(cur_delta, vec_mod)

## Add a new concept that is concept + lambda (he - she)

In [None]:

# Load Concepts
def load_learned_embed_in_clip(added_embeds, text_encoder, tokenizer, token=None):
    """Register one learned embedding as a new token of the tokenizer / text encoder.

    Args:
        added_embeds: mapping from token string to its embedding; only the first
            entry is used.
        text_encoder: model exposing ``get_input_embeddings()`` and
            ``resize_token_embeddings()``.
        tokenizer: tokenizer exposing ``add_tokens()``, ``__len__`` and
            ``convert_tokens_to_ids()``.
        token: name to register the embedding under; defaults to the key found in
            ``added_embeds``.

    Raises:
        ValueError: if the tokenizer already contains ``token``.
    """
    trained_token = list(added_embeds.keys())[0]
    embeds = torch.tensor(added_embeds[trained_token])

    # Tensor.to() returns a new tensor rather than converting in place, so the result
    # has to be kept (the previous code discarded it).
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)

    token = token if token is not None else trained_token
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens == 0:
        raise ValueError(f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer.")

    # Grow the embedding matrix so the new token id has a row, then write that row.
    text_encoder.resize_token_embeddings(len(tokenizer))

    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds



In [None]:
# Index into `vecs` / `words` of the concept to modify.
concept_index = 421

# Register one new token per scaling factor, each holding concept + fac * vec_mod.
# Factors that also occur in old_scaling_fac are skipped.
# NOTE(review): the membership test compares floats produced by np.arange exactly, and
# the token name embeds the unrounded float, so names may look like
# <mod-anime-girl--0.7999...>; the generation cell must format `fac` identically.
scaling_fac = np.arange(-4, 4, 0.8)
old_scaling_fac = np.arange(-2, 2, 0.2)
for fac in scaling_fac:
    if fac not in old_scaling_fac:
        new_v = vecs[concept_index] + vec_mod * fac
        load_learned_embed_in_clip({f'<mod-anime-girl-{fac}>': new_v}, pipe.text_encoder, pipe.tokenizer)
# Naming used across experiments:
# mod-anime-girl = she, he
# mod-anime-girl2 = woman, man
# mod-anime-girl3 = normed woman, man
# mod-anime-girl4 = non-normed woman, man

In [None]:

# images = pipe(f"A photo of <anime-girl>", num_images_per_prompt=1, num_inference_steps=50, guidance_scale=7.5)["sample"]
# for img in images:
#     plt.imshow(img)
#     plt.show()


In [None]:
# Generate four images for each registered scaling factor.
# Note: this rebinds `imgs`, which earlier cells used for the list of thumbnail paths.
# imgs maps fac -> [list of generated images] (a one-element list holding the batch).
imgs = {}
for fac in scaling_fac:
    if fac not in old_scaling_fac:
        imgs[fac] = []
        images = pipe(f"A photo of <mod-anime-girl-{fac}>", num_images_per_prompt=4, num_inference_steps=50, guidance_scale=7.5)["images"]
        imgs[fac].append(images)

In [None]:
# Scaling factors that have generated images, in ascending order.
fac_keys = sorted(imgs.keys())

In [None]:
fac_keys

In [None]:
# One image batch per scaling factor, ordered by factor.
images_to_combine = [imgs[fac][0] for fac in fac_keys]

In [None]:
import PIL
from PIL import Image

def combine_images(images, imgs_per_set = 4):
  """Tile image sets into one picture: one column per set, images stacked top to bottom.

  Args:
    images: sequence of image sets; each set is a sequence of PIL images.
    imgs_per_set: number of rows the canvas is sized for.

  Returns:
    A new RGB image whose size is derived from the first image of the first set.
  """
  tile_width, tile_height = images[0][0].size
  canvas = Image.new('RGB', (len(images) * tile_width, imgs_per_set * tile_height))

  left = 0
  for column in images:
    top = 0
    for tile in column:
      canvas.paste(tile, (left, top))
      top += tile.size[1]
    # Advance by the width of the last image pasted in this column.
    left += tile.size[0]

  return canvas

In [None]:
combined = combine_images(images_to_combine)
combined

In [None]:
# For every set, randomly keep two of the images that are not entirely zero-valued
# (all-black) pictures.
# NOTE(review): np.random.choice raises if a set has fewer than two such images, and
# the selection is unseeded, so it differs between runs.
images_to_combine_len_2 = []
for im_set in images_to_combine:
    inds = []
    for i in range(4):
        cur_im = im_set[i]
        cur_im = np.array(cur_im)
        if not np.all(cur_im == 0):
            inds.append(i)

    inds = np.random.choice(inds, size=2, replace=False)

    cur_set = [im_set[ind] for ind in inds]

    images_to_combine_len_2.append(cur_set)

In [None]:
combined = combine_images(images_to_combine)
combined.save('test.png')

In [None]:
combined2 = combine_images(images_to_combine_len_2, imgs_per_set=2)
combined2.save('test.png')

In [None]:
!pwd

In [None]:
# redo with coraline
def test_func(base_concept):
    """Register ``concept - woman + man`` as a new token and display four samples of it.

    Args:
        base_concept: concept token wrapped in angle brackets, e.g. '<sam-yang>'.

    Returns:
        The modified embedding that was registered as ``<mod-NAME>``.

    Raises:
        ValueError: from ``words.index`` if the concept is unknown, or from
            ``load_learned_embed_in_clip`` if ``<mod-NAME>`` is already registered
            (for instance when this is called twice for the same concept).
    """
    concept_name = base_concept[1:-1]  # strip the surrounding '<' and '>'
    concept_v = vecs[words.index(concept_name)]

    she_vec = raw_emb_dict['woman'].numpy()
    he_vec = raw_emb_dict['man'].numpy()

    new_concept_v = concept_v - she_vec + he_vec

    # Register the vector computed above. The previous code passed the unrelated
    # notebook-level `new_v` left over from another cell.
    load_learned_embed_in_clip({f'<mod-{concept_name}>': new_concept_v}, pipe.text_encoder, pipe.tokenizer)

    # The generated pictures are read from the "images" entry, as in the other
    # generation cells of this notebook.
    images = pipe(f"A photo of <mod-{concept_name}>", num_images_per_prompt=4, num_inference_steps=50, guidance_scale=7.5)["images"]

    for img in images:
        plt.imshow(img)
        plt.show()

    return new_concept_v

In [None]:
# test_func('<sam-yang>')
test_func('<anya-forger>')

In [None]:
# TODO: 
test_func('<nouns-glasses>')

In [None]:
def register_concept(base_concept='<anya-forger>'):
    """Register a library concept's stored embedding with the pipeline under its own token.

    Args:
        base_concept: concept token wrapped in angle brackets; the text between the
            brackets is looked up in ``words``.
    """
    bare_name = base_concept[1:-1]
    embedding = vecs[words.index(bare_name)]

    load_learned_embed_in_clip({base_concept: embedding}, pipe.text_encoder, pipe.tokenizer)

In [None]:
# Read the generated pictures from the "images" entry, the key the other generation
# cells in this notebook use ("sample" is not used anywhere else).
images = pipe("A photo of <anya-forger> with blue hair", num_images_per_prompt=3, num_inference_steps=50, guidance_scale=7.5)["images"]
for img in images:
    plt.imshow(img)
    plt.show()

## Resuming

In [None]:
def get_single_word_emb(word):
    """Return the text encoder's input embedding for one tokenizer token, on the CPU."""
    token_id = pipe.tokenizer.convert_tokens_to_ids([word])
    embedding_matrix = pipe.text_encoder.get_input_embeddings().weight.data
    return embedding_matrix[token_id][0].cpu()

In [None]:
for x in word_pair_diffs:
    for y in word_pair_diffs:
        print(cos_sim(x, y))
    print()

In [None]:
# Analogy check: for each pair, add every pair-difference vector to the masculine word
# and print how close (cosine similarity) the result is to the feminine word.
for s, h in word_pairs:
    print(f'\n{s}', h)
    for v in word_pair_diffs:
        print(cos_sim(emb_dict[h] + v, emb_dict[s]))


## Projection

In [None]:
def scalar_project_x_onto_y(x, y):
    """Return the coefficient c = dot(x, y) / dot(y, y), so that c * y is the projection of x onto y."""
    y_squared_norm = np.dot(y, y)
    return np.dot(x, y) / y_squared_norm

In [None]:
# Projection coefficient of every concept embedding onto the first pair difference
# (the 'she' minus 'he' direction).
proj_v = [scalar_project_x_onto_y(concept_vec, word_pair_diffs[0]) for concept_vec in vecs]

In [None]:
# Pair each concept name with its projection coefficient.
words_and_proj_v = [(words[i], score) for i, score in enumerate(proj_v)]

In [None]:
# TODO: should they be normalized first??
sorted_concepts_by_gender = sorted(words_and_proj_v, key=lambda x: x[1])

In [None]:
from PIL import Image

def get_image(word):
    """Load the first example image of a concept as an RGB array.

    Args:
        word: concept name, i.e. the folder name under ``sd-concepts-library``.

    Returns:
        The image as an array in RGB channel order, or ``None`` when the file is
        missing or cannot be decoded.
    """
    img = cv2.imread(os.path.join(os.getcwd(), f"sd-concepts-library/{word}/0.jpeg"))
    if img is None:
        return None
    # OpenCV decodes colour images in BGR order, while matplotlib's imshow expects
    # RGB. Without the conversion red and blue are swapped on display.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

In [None]:
# Take the ten concepts with the largest projection coefficient (the list is sorted
# in ascending order, so they are the last ten entries).
selected_concepts = sorted_concepts_by_gender[-10:]

# Keep only the concept names.
selected_concepts = [x[0] for x in selected_concepts]

combined_images = [get_image(x) for x in selected_concepts]

# Show each concept's example image.
for x in combined_images:
    plt.imshow(x)
    plt.show()

### Work (abandoned)

In [None]:
pipe.tokenizer.convert_ids_to_tokens([11001])

In [None]:
pipe.text_encoder.get_input_embeddings().weight.data.shape

In [None]:
# NOTE(review): `token_id` and `embeds` are not assigned at notebook level in any
# visible cell (they only exist inside load_learned_embed_in_clip), so this scratch
# cell appears to fail on a fresh kernel - confirm or remove.
pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embeds

In [None]:
# import os

# from transformers import CLIPTextModel, CLIPTokenizer

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16).to("cuda")
