In [None]:
import torch
import clip
from PIL import Image
import numpy as np
from diffusers import StableDiffusionPipeline
import torch
import os
import random
from pytorch_fid import fid_score

In [4]:
# Free memory left over from earlier runs in this kernel.
# gc.collect() runs first so tensors that are no longer referenced are actually
# released; torch.cuda.empty_cache() can only return cached blocks that nothing
# references any more.
import gc

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("CUDA cache emptied.")
else:
    print("CUDA is not available.")
    

CUDA cache emptied.


0

In [5]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [5]:
# One-time environment setup, left commented so "Run All" does not reinstall packages.
# NOTE(review): prefer `%pip install` (targets this kernel's environment) and pin versions.
# !pip install diffusers[flax] accelerate transformers
# !pip install xformers

In [None]:
# Textual-inversion training launcher, left commented so "Run All" does not retrain.
# NOTE(review): presumably this produces the embeddings/... folders loaded below — confirm.
# !bash launch_textualInversion.sh

In [6]:
# CLIP backbone used only for scoring prompt/image alignment (not for generation).
CLIP_BACKBONE = "ViT-B/32"
clip_model, preprocess = clip.load(CLIP_BACKBONE, device=device)


In [7]:
def generate_images(prompt, num_images, output_dir, output_subdir, embeddings,
                    num_inference_steps=50, guidance_scale=7.5, seed=None,
                    model_id="runwayml/stable-diffusion-v1-5"):
    """Generate images with Stable Diffusion plus a textual-inversion embedding and save them as PNGs.

    Files are written to ``{output_dir}/{output_subdir}/{name}_{i}.png``, where ``name``
    is the last ``/``-separated component of ``output_subdir``.

    Args:
        prompt: Text prompt; expected to contain the placeholder token of ``embeddings``.
        num_images: Number of images to generate (one pipeline call per image).
        output_dir: Root output folder (created if missing).
        output_subdir: Sub-folder under ``output_dir`` (created if missing).
        embeddings: Passed unchanged to ``pipe.load_textual_inversion``.
        num_inference_steps: Denoising steps per image (previous hard-coded value: 50).
        guidance_scale: Classifier-free guidance scale (previous hard-coded value: 7.5).
        seed: If given, seeds a ``torch.Generator`` so runs are reproducible;
            ``None`` keeps the previous unseeded behaviour.
        model_id: Base Stable Diffusion checkpoint to load.

    Returns:
        List of saved image paths, in generation order.
    """
    save_dir = f"{output_dir}/{output_subdir}"
    # exist_ok replaces the exists()/makedirs() pair and also creates output_dir.
    os.makedirs(save_dir, exist_ok=True)
    # File-name stem depends only on output_subdir, so compute it once.
    image_name = output_subdir.split("/")[-1]

    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
    pipe.load_textual_inversion(embeddings)
    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)

    gen_images = []
    for i in range(num_images):
        result = pipe(prompt, num_inference_steps=num_inference_steps,
                      guidance_scale=guidance_scale, generator=generator)
        image = result.images[0]
        # When the safety checker flags an output, the pipeline returns a black image.
        # It is still saved (so the return length is unchanged), but it drags down
        # CLIP/FID scores, so make it visible.
        nsfw_flags = getattr(result, "nsfw_content_detected", None)
        if nsfw_flags and nsfw_flags[0]:
            print(f"Warning: image {i} was flagged by the safety checker and is a black image.")
        image_path = f"{save_dir}/{image_name}_{i}.png"
        image.save(image_path)
        print(f"Image {i} saved.")
        gen_images.append(image_path)
    return gen_images


def compute_clip_similarity(gen_images, prompt):
    """Mean CLIP cosine similarity between ``prompt`` and each image in ``gen_images``.

    Uses the module-level ``clip_model``, ``preprocess`` and ``device``.

    Args:
        gen_images: Iterable of image file paths.
        prompt: Text prompt to compare against. Prompts longer than CLIP's context
            window are truncated rather than raising.

    Returns:
        Mean cosine similarity (numpy float). Returns ``nan`` when ``gen_images`` is empty.

    NOTE(review): the textual-inversion placeholder (e.g. ``<cat>``) was learned for the
    generator's text encoder, presumably not this CLIP model, so here it is likely just
    ordinary text. Confirm before interpreting scores.
    """
    tokens = clip.tokenize([prompt], truncate=True).to(device)

    similarity_scores = []
    with torch.no_grad():
        text_features = clip_model.encode_text(tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        for image_path in gen_images:
            # Context manager closes the file handle once preprocessing has read the pixels.
            with Image.open(image_path) as img:
                image_input = preprocess(img).unsqueeze(0).to(device)
            image_features = clip_model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # (1, D) @ (D, 1) -> single cosine similarity
            similarity_scores.append((image_features @ text_features.T).item())

    if not similarity_scores:
        return float("nan")
    return np.mean(similarity_scores)


def fid_score_calc(real_img_dir, gen_img_dir, batch_size=2, dims=2048):
    """Frechet Inception Distance between two image folders, via pytorch-fid.

    Args:
        real_img_dir: Folder of reference images.
        gen_img_dir: Folder of generated images.
        batch_size: Inception batch size (previous hard-coded value: 2).
        dims: Inception feature dimensionality (previous hard-coded value: 2048).

    Returns:
        The FID value (lower is closer).

    Caveat: with N images, the sample covariance has rank at most N - 1. With the
    5-40 images per folder used in this notebook and dims=2048, both covariances are
    rank-deficient. The resulting score is therefore only a rough, noisy indicator.
    """
    return fid_score.calculate_fid_given_paths(
        [real_img_dir, gen_img_dir], batch_size=batch_size, device=device, dims=dims
    )


def resize_images(image_dir, target_size=(256, 256)):
    """Resize every PNG/JPEG directly inside ``image_dir`` to ``target_size``, in place.

    WARNING: destructive — each file is overwritten with its resized version.

    Files already at ``target_size`` are left untouched. Repeated calls therefore do not
    keep re-encoding (and, for JPEG, degrading) the same images. Extension matching is
    case-insensitive.

    Args:
        image_dir: Directory whose image files are resized.
        target_size: ``(width, height)`` to resize to.
    """
    valid_extensions = ('.png', '.jpg', '.jpeg')
    target_size = tuple(target_size)
    for img_name in os.listdir(image_dir):
        if not img_name.lower().endswith(valid_extensions):
            continue
        img_path = os.path.join(image_dir, img_name)
        # Close the source file before overwriting it; resize() returns an independent copy.
        with Image.open(img_path) as img:
            if img.size == target_size:
                continue
            resized = img.resize(target_size)
        resized.save(img_path)


from PIL import Image
import os
import matplotlib.pyplot as plt

def display_images_in_directory(directory, cols=3):
    """Show every PNG/JPEG in ``directory`` in a grid, each titled with its file name.

    The grid grows with the number of images. A fixed 3x3 layout would raise
    ValueError for folders holding more than nine images.

    Args:
        directory: Folder to display.
        cols: Number of grid columns (default 3, matching the previous layout).
    """
    image_files = sorted(
        f for f in os.listdir(directory) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    if not image_files:
        print(f"No images found in {directory}")
        return

    cols = max(1, cols)
    rows = (len(image_files) + cols - 1) // cols
    plt.figure(figsize=(10, max(5, 3 * rows)))

    for i, image_file in enumerate(image_files, start=1):
        # Convert to an array while the file is open so the handle can be closed.
        with Image.open(os.path.join(directory, image_file)) as img:
            pixels = np.asarray(img)
        plt.subplot(rows, cols, i)
        plt.imshow(pixels)
        plt.axis('off')
        plt.title(image_file)

    plt.tight_layout()
    plt.show()


import os
import matplotlib.pyplot as plt
from PIL import Image

def display_gen_imgs(dir, max_images=10, cols=5):
    """Show up to ``max_images`` PNG/JPEG files from a folder (sorted by name) in a grid.

    Args:
        dir: Folder to display. The name shadows the builtin but is kept so keyword
            callers keep working.
        max_images: Maximum number of images to show.
        cols: Maximum number of grid columns (default 5). Narrowed automatically when
            fewer images are shown.
    """
    image_dir = dir
    image_files = sorted(
        f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )[:max_images]
    if not image_files:
        # Avoids creating a zero-height figure.
        print(f"No images found in {image_dir}")
        return

    cols = max(1, min(cols, len(image_files)))
    rows = (len(image_files) + cols - 1) // cols

    plt.figure(figsize=(4 * cols, 4 * rows))
    for i, image_file in enumerate(image_files, start=1):
        # Read pixels while the file is open so the handle is released right away.
        with Image.open(os.path.join(image_dir, image_file)) as img:
            pixels = np.asarray(img)
        plt.subplot(rows, cols, i)
        plt.imshow(pixels)
        plt.axis('off')
        plt.title(image_file, fontsize=8)

    plt.tight_layout()
    plt.show()




## Wikiart dataset

1. **Impressionism** class

In [None]:
# Impressionism, prompt 1: lakeside at sunset. Generate, then score prompt alignment with CLIP.
prompt_test = "<impressionism-style> A serene lakeside scene at sunset, with soft brushstrokes and warm pastel colors, capturing the movement of light on water."
output_dir = "Prompt_Images/Impressionism"
num_images = 10

gen_images = generate_images(
    prompt=prompt_test,
    num_images=num_images,
    output_dir=output_dir,
    output_subdir="lakeside",
    embeddings="embeddings/Impressionism",
)
similarity_score = compute_clip_similarity(gen_images=gen_images, prompt=prompt_test)
print(f"CLIP Similarity Score: {similarity_score}")




In [12]:
# FID: real Impressionism paintings vs. the "lakeside" generations.
# resize_images overwrites the files in place.
real_img_dir = "wikiart_curated/kaggle/working/wikiart_curated/Impressionism"
gen_img_dir = f"{output_dir}/lakeside"
resize_images(real_img_dir)

for label, folder in (("Real Images:", real_img_dir), ("Generated Images:", gen_img_dir)):
    print(label, len(os.listdir(folder)))
print(gen_img_dir)

resize_images(gen_img_dir)
fid_score_val = fid_score_calc(real_img_dir, gen_img_dir)
print(f"FID Score: {fid_score_val}")

Real Images: 40
Generated Images: 10
Prompt_Images/Impressionism/lakeside


100%|██████████| 20/20 [00:00<00:00, 39.17it/s]
100%|██████████| 5/5 [00:00<00:00, 28.46it/s]


FID Score: 310.3736006785706


In [None]:
# Impressionism, prompt 2: women on a balcony. Generate, then score with CLIP.
prompt_test = "<impressionism-style> A woman seated indoors near a sunlit balcony, with soft, textured brushstrokes and warm, muted tones, evoking a serene and reflective atmosphere watching another woman outside the balcony."
output_dir = "Prompt_Images/Impressionism"
num_images = 10

gen_images = generate_images(
    prompt=prompt_test,
    num_images=num_images,
    output_dir=output_dir,
    output_subdir="Women_on_Balcony",
    embeddings="embeddings/Impressionism",
)
similarity_score = compute_clip_similarity(gen_images=gen_images, prompt=prompt_test)
print(f"CLIP Similarity Score: {similarity_score}")

In [15]:
# FID: real Impressionism paintings vs. the "Women_on_Balcony" generations.
real_img_dir = "wikiart_curated/kaggle/working/wikiart_curated/Impressionism"
gen_img_dir = f"{output_dir}/Women_on_Balcony"
for folder in (real_img_dir, gen_img_dir):
    resize_images(folder)  # overwrites files in place

fid_score_val = fid_score_calc(real_img_dir, gen_img_dir)
print(f"FID Score: {fid_score_val}")

100%|██████████| 20/20 [00:00<00:00, 37.23it/s]
100%|██████████| 8/8 [00:00<00:00, 29.78it/s]


FID Score: 388.29444870692066


In [None]:
# Preview the "Women_on_Balcony" generations.
gen_img_dir = f"{output_dir}/Women_on_Balcony"
display_gen_imgs(dir=gen_img_dir)

2. **Abstract Expressionism** class

In [None]:
# Abstract Expressionism: generate with the learned style token, then score with CLIP.
prompt_test = "<Abstract_Expressionism-style> A dynamic explosion of colors and shapes, with bold brushstrokes and a sense of movement, capturing the essence of abstract expressionism."
output_dir = "Prompt_Images/Abstract_Expressionism"
num_images = 10

gen_images = generate_images(
    prompt=prompt_test,
    num_images=num_images,
    output_dir=output_dir,
    output_subdir="abstract_expressionism",
    embeddings="embeddings/Abstract_Expressionism",
)
similarity_score = compute_clip_similarity(gen_images=gen_images, prompt=prompt_test)
print(f"CLIP Similarity Score: {similarity_score}")

In [8]:
# FID: real Abstract Expressionism paintings vs. the generated images.
real_img_dir = "wikiart_curated/kaggle/working/wikiart_curated/Abstract_Expressionism"
gen_img_dir = f"{output_dir}/abstract_expressionism"
print("Real Images:", len(os.listdir(real_img_dir)))
print("Generated Images:", len(os.listdir(gen_img_dir)))
print(gen_img_dir)

# resize_images overwrites files in place. The real folder is resized once only;
# resizing it twice just re-encodes the JPEGs again.
resize_images(real_img_dir)
resize_images(gen_img_dir)

fid_score_val = fid_score_calc(real_img_dir, gen_img_dir)
print(f"FID Score: {fid_score_val}")

Real Images: 40
Generated Images: 10
Prompt_Images/Abstract_Expressionism/abstract_expressionism


100%|██████████| 20/20 [00:00<00:00, 28.62it/s]
100%|██████████| 5/5 [00:00<00:00, 15.69it/s]


FID Score: 446.05996104902334


3. **Action_painting** class

In [None]:
# Preview a reproducible random sample of (up to) 10 real Action Painting images.
# os and random come from the imports cell at the top.
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

action_real_dir = "wikiart_curated/kaggle/working/wikiart_curated/Action_painting"
PREVIEW_COUNT = 10  # the 2x5 grid below holds at most 10 images
PREVIEW_SEED = 0    # fixed seed so every run shows the same images

# Sorted first so the seeded sample does not depend on os.listdir order.
image_files = sorted(
    f for f in os.listdir(action_real_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
)
# random.sample raises ValueError when asked for more items than exist.
random_images = random.Random(PREVIEW_SEED).sample(image_files, min(PREVIEW_COUNT, len(image_files)))

plt.figure(figsize=(20, 10))
for i, img_file in enumerate(random_images, start=1):
    plt.subplot(2, 5, i)
    plt.imshow(mpimg.imread(os.path.join(action_real_dir, img_file)))
    plt.axis('off')
plt.show()


In [None]:
# Action Painting: generate with the learned style token, then score with CLIP.
prompt_test = "<Action_painting-style> A dark and chaotic action painting, aggressive brushstrokes, bold streaks of black, red, and white, raw emotional energy on canvas"
output_dir = "Prompt_Images/Action_painting"
num_images = 10

gen_images = generate_images(
    prompt=prompt_test,
    num_images=num_images,
    output_dir=output_dir,
    output_subdir="action_painting",
    embeddings="embeddings/Action_painting",
)
similarity_score = compute_clip_similarity(gen_images=gen_images, prompt=prompt_test)
print(f"CLIP Similarity Score: {similarity_score}")




In [None]:
# Preview the generated Action Painting images.
action_painting_dir = "Prompt_Images/Action_painting/action_painting"
display_gen_imgs(dir=action_painting_dir)

In [11]:
# FID: real Action Painting images vs. the generated ones.
real_img_dir = "wikiart_curated/kaggle/working/wikiart_curated/Action_painting"
gen_img_dir = f"{output_dir}/action_painting"

for label, folder in (("Real Images:", real_img_dir), ("Generated Images:", gen_img_dir)):
    print(label, len(os.listdir(folder)))
print(gen_img_dir)

# Both folders are resized in place to resize_images' default size before scoring.
for folder in (real_img_dir, gen_img_dir):
    resize_images(folder)

fid_score_val = fid_score_calc(real_img_dir, gen_img_dir)
print(f"FID Score: {fid_score_val}")

Real Images: 40
Generated Images: 10
Prompt_Images/Action_painting/action_painting


100%|██████████| 20/20 [00:00<00:00, 24.67it/s]
100%|██████████| 5/5 [00:00<00:00, 13.15it/s]


FID Score: 465.28043514036847


4. **Analytical_Cubism** class

In [None]:
# Preview a reproducible random sample of (up to) 10 real Analytical Cubism images.
# os and random come from the imports cell at the top.
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

cubism_real_dir = "wikiart_curated/kaggle/working/wikiart_curated/Analytical_Cubism"
PREVIEW_COUNT = 10  # the 2x5 grid below holds at most 10 images
PREVIEW_SEED = 0    # fixed seed so every run shows the same images

# Sorted first so the seeded sample does not depend on os.listdir order.
image_files = sorted(
    f for f in os.listdir(cubism_real_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
)
# random.sample raises ValueError when asked for more items than exist.
random_images = random.Random(PREVIEW_SEED).sample(image_files, min(PREVIEW_COUNT, len(image_files)))

plt.figure(figsize=(20, 10))
for i, img_file in enumerate(random_images, start=1):
    plt.subplot(2, 5, i)
    plt.imshow(mpimg.imread(os.path.join(cubism_real_dir, img_file)))
    plt.axis('off')
plt.show()


In [11]:
# Analytical Cubism: generate with the learned style token, then score with CLIP.
# In the recorded run, one image was flagged by the safety checker and saved as a black
# frame (see output). That frame is included in the CLIP and FID numbers.
prompt_test = "<Analytical_Cubism-style> A distorted cubist portrait, angular planes, multiple perspectives merged into one, abstract facial features, using earthy browns, greys, and ochre."
output_dir = "Prompt_Images/Analytical_Cubism"
num_images = 10

gen_images = generate_images(
    prompt=prompt_test,
    num_images=num_images,
    output_dir=output_dir,
    output_subdir="analytical_cubism",
    embeddings="embeddings/Analytical_Cubism",
)
similarity_score = compute_clip_similarity(gen_images=gen_images, prompt=prompt_test)
print(f"CLIP Similarity Score: {similarity_score}")


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Image 0 saved.


  0%|          | 0/50 [00:00<?, ?it/s]

Image 1 saved.


  0%|          | 0/50 [00:00<?, ?it/s]

Image 2 saved.


  0%|          | 0/50 [00:00<?, ?it/s]

Image 3 saved.


  0%|          | 0/50 [00:00<?, ?it/s]

Image 4 saved.


  0%|          | 0/50 [00:00<?, ?it/s]

Image 5 saved.


  0%|          | 0/50 [00:00<?, ?it/s]

Image 6 saved.


  0%|          | 0/50 [00:00<?, ?it/s]

Image 7 saved.


  0%|          | 0/50 [00:00<?, ?it/s]

Image 8 saved.


  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


Image 9 saved.
CLIP Similarity Score: 0.33416748046875


In [12]:
# FID: real Analytical Cubism paintings vs. the generated images.
# resize_images overwrites the files in place.
real_img_dir = "wikiart_curated/kaggle/working/wikiart_curated/Analytical_Cubism"
gen_img_dir = f"{output_dir}/analytical_cubism"
resize_images(real_img_dir)

for label, folder in (("Real Images:", real_img_dir), ("Generated Images:", gen_img_dir)):
    print(label, len(os.listdir(folder)))

resize_images(gen_img_dir)
fid_score_val = fid_score_calc(real_img_dir, gen_img_dir)
print(f"FID Score: {fid_score_val}")

Real Images: 40
Generated Images: 10


100%|██████████| 20/20 [00:00<00:00, 26.37it/s]
100%|██████████| 5/5 [00:00<00:00, 21.40it/s]


FID Score: 348.2420193210146


## Dreambooth dataset

1. **Cat** class

In [None]:
# Show the DreamBooth reference photos for the cat subject.
# Named cat_dir rather than `dir`, which would shadow Python's builtin dir() for the session.
cat_dir = "dataset/cat"
display_images_in_directory(cat_dir)

In [None]:
# DreamBooth cat: generate the learned <cat> subject at sunset, then score with CLIP.
output_dir = "Prompt_Images/Dreambooth/cat"
output_subdir = "cat_sunset"
embeddings = "embeddings/dreambooth/cat"
prompt = "A photo of <cat> in a grassy field during sunset"
num_images = 10

gen_images = generate_images(
    prompt=prompt,
    num_images=num_images,
    output_dir=output_dir,
    output_subdir=output_subdir,
    embeddings=embeddings,
)
similarity_score = compute_clip_similarity(gen_images=gen_images, prompt=prompt)
print(f"Similarity score for |{output_dir}| -> |{output_subdir}| : {similarity_score}")

In [None]:
# FID: DreamBooth reference cat photos vs. the "cat_sunset" generations.
# With only a handful of reference photos, treat this FID as a rough indicator.
real_img_dir = "dataset/cat"
gen_img_dir = f"{output_dir}/{output_subdir}"
print("Real Images:", len(os.listdir(real_img_dir)))
print("Generated Images:", len(os.listdir(gen_img_dir)))

# resize_images overwrites the files in place.
resize_images(real_img_dir)
resize_images(gen_img_dir)
print("Images resized!")

fid_score_val = fid_score_calc(real_img_dir, gen_img_dir)
print(f"FID score for |{real_img_dir}| -> |{gen_img_dir}| : {fid_score_val}")

Real Images: 5
Generated Images: 10
Valid Real Images: 5
Valid Generated Images: 10
Images resized!


100%|██████████| 3/3 [00:00<00:00,  5.29it/s]
100%|██████████| 5/5 [00:00<00:00, 25.11it/s]


FID score for |dataset/cat| -> |Prompt_Images/Dreambooth/cat/cat_sunset| : 201.6868119294298


2. **dog** class

In [None]:
# Show all DreamBooth reference photos for the dog subject.
# Named dog_dir rather than `dir`, which would shadow Python's builtin dir().
# The modules used here are already imported at the top of the notebook.
dog_dir = "dataset/dog"
display_images_in_directory(dog_dir)
    

In [None]:
# DreamBooth dog: generate the learned <dog> subject at sunset (scored in the next cell).
output_dir = "Prompt_Images/Dreambooth/dog"
output_subdir = "dog_sunset"
embeddings = "embeddings/dreambooth/dog"
prompt = "A photo of <dog> in a grassy field during sunset"
num_images = 10

gen_images = generate_images(
    prompt=prompt,
    num_images=num_images,
    output_dir=output_dir,
    output_subdir=output_subdir,
    embeddings=embeddings,
)


In [10]:
# CLIP prompt-alignment score for the dog generations from the previous cell.
similarity_score = compute_clip_similarity(gen_images=gen_images, prompt=prompt)
print(f"Similarity score for |{output_dir}| -> |{output_subdir}| : {similarity_score}")

Similarity score for |Prompt_Images/Dreambooth/dog| -> |dog_sunset| : 0.26429443359375


In [11]:
# FID: DreamBooth reference dog photos vs. the "dog_sunset" generations.
real_img_dir = "dataset/dog"
gen_img_dir = f"{output_dir}/{output_subdir}"
for label, folder in (("Real Images:", real_img_dir), ("Generated Images:", gen_img_dir)):
    print(label, len(os.listdir(folder)))

# Both folders are resized in place to resize_images' default size before scoring.
for folder in (real_img_dir, gen_img_dir):
    resize_images(folder)

fid_score_val = fid_score_calc(real_img_dir, gen_img_dir)
print(f"FID score for |{real_img_dir}| -> |{gen_img_dir}| : {fid_score_val}")

Real Images: 5
Generated Images: 10


100%|██████████| 3/3 [00:00<00:00,  7.17it/s]
100%|██████████| 5/5 [00:00<00:00, 23.97it/s]


FID score for |dataset/dog| -> |Prompt_Images/Dreambooth/dog/dog_sunset| : 174.79761776799165


In [None]:
# Smoke test: generate a single image from the bare <dog> token.
# Uses its own variable names so it does not overwrite num_images / output_dir /
# output_subdir / embeddings / gen_images. The dog CLIP and FID cells above read
# those, so re-running them after this cell would otherwise score the wrong folder.
test_prompt = "A photo of <dog>"
test_images = generate_images(
    test_prompt,
    1,
    "Prompt_Images/test",
    "testing",
    embeddings="embeddings/dreambooth/dog",
)



3. **shiny_sneaker** class

In [None]:
# Show the DreamBooth reference photos for the shiny sneaker subject.
# Named sneaker_dir rather than `dir`, which would shadow Python's builtin dir().
sneaker_dir = "dataset/shiny_sneaker"
display_images_in_directory(sneaker_dir)


In [None]:
# DreamBooth shiny sneaker: studio-style generations, then CLIP scoring.
output_dir = "Prompt_Images/Dreambooth/shiny_sneaker"
output_subdir = "sneaker_studio"
embeddings = "embeddings/dreambooth/shiny_sneaker"
prompt = "A photo of <sneaker> on a reflective surface, studio lighting, high detail, 8k resolution"
num_images = 10

gen_images = generate_images(
    prompt=prompt,
    num_images=num_images,
    output_dir=output_dir,
    output_subdir=output_subdir,
    embeddings=embeddings,
)
similarity_score = compute_clip_similarity(gen_images=gen_images, prompt=prompt)
print(f"Similarity score for |{output_dir}| -> |{output_subdir}| : {similarity_score}")

In [9]:
# FID: DreamBooth reference sneaker photos vs. the "sneaker_studio" generations.
real_img_dir = "dataset/shiny_sneaker"
gen_img_dir = f"{output_dir}/{output_subdir}"
for label, folder in (("Real Images:", real_img_dir), ("Generated Images:", gen_img_dir)):
    print(label, len(os.listdir(folder)))

# Both folders are resized in place to resize_images' default size before scoring.
for folder in (real_img_dir, gen_img_dir):
    resize_images(folder)

fid_score_val = fid_score_calc(real_img_dir, gen_img_dir)
print(f"FID score for |{real_img_dir}| -> |{gen_img_dir}| : {fid_score_val}")

Real Images: 6
Generated Images: 10


100%|██████████| 3/3 [00:00<00:00,  5.63it/s]
100%|██████████| 5/5 [00:00<00:00, 19.48it/s]


FID score for |dataset/shiny_sneaker| -> |Prompt_Images/Dreambooth/shiny_sneaker/sneaker_studio| : 323.9680244701582


4. **pink_sunglasses** class

In [None]:
# Show the DreamBooth reference photos for the pink sunglasses subject.
# Named sunglasses_dir rather than `dir`, which would shadow Python's builtin dir().
sunglasses_dir = "dataset/pink_sunglasses"
display_images_in_directory(sunglasses_dir)

In [None]:
# DreamBooth pink sunglasses, prompt 1: on silk fabric. Generate, then score with CLIP.
output_dir = "Prompt_Images/Dreambooth/pink_sunglasses"
output_subdir = "sunglasses_silk"
embeddings = "embeddings/dreambooth/pink_sunglasses"
prompt = "A pink <sunglasses> on a soft silk fabric, aesthetic and elegant."
num_images = 10

gen_images = generate_images(
    prompt=prompt,
    num_images=num_images,
    output_dir=output_dir,
    output_subdir=output_subdir,
    embeddings=embeddings,
)
similarity_score = compute_clip_similarity(gen_images=gen_images, prompt=prompt)
print(f"Similarity score for |{output_dir}| -> |{output_subdir}| : {similarity_score}")

In [None]:
# FID: DreamBooth reference sunglasses photos vs. the "sunglasses_silk" generations.
real_img_dir = "dataset/pink_sunglasses"
gen_img_dir = f"{output_dir}/{output_subdir}"
for label, folder in (("Real Images:", real_img_dir), ("Generated Images:", gen_img_dir)):
    print(label, len(os.listdir(folder)))

# Both folders are resized in place to resize_images' default size before scoring.
for folder in (real_img_dir, gen_img_dir):
    resize_images(folder)

fid_score_val = fid_score_calc(real_img_dir, gen_img_dir)
print(f"FID score for |{real_img_dir}| -> |{gen_img_dir}| : {fid_score_val}")

Real Images: 6
Generated Images: 10


100%|██████████| 3/3 [00:00<00:00, 14.83it/s]
100%|██████████| 5/5 [00:00<00:00, 24.33it/s]


FID score for |dataset/pink_sunglasses| -> |Prompt_Images/Dreambooth/pink_sunglasses/sunglasses_silk| : 303.30047267339734


Second prompt for the **pink_sunglasses** class: sunglasses on a marble countertop.

In [None]:
# DreamBooth pink sunglasses, prompt 2: on a marble countertop. Generate, then score with CLIP.
output_dir = "Prompt_Images/Dreambooth/pink_sunglasses"
output_subdir = "sunglasses_marble"
embeddings = "embeddings/dreambooth/pink_sunglasses"
prompt = "A pair of elegant pink <sunglasses> resting on a marble countertop, surrounded by soft natural light."
num_images = 10

gen_images = generate_images(
    prompt=prompt,
    num_images=num_images,
    output_dir=output_dir,
    output_subdir=output_subdir,
    embeddings=embeddings,
)
similarity_score = compute_clip_similarity(gen_images=gen_images, prompt=prompt)
print(f"Similarity score for |{output_dir}| -> |{output_subdir}| : {similarity_score}")

In [None]:
# Preview the "sunglasses_marble" generations.
gen_img_dir = f"{output_dir}/{output_subdir}"
display_gen_imgs(dir=gen_img_dir)

In [23]:
# FID: DreamBooth reference sunglasses photos vs. the "sunglasses_marble" generations.
real_img_dir = "dataset/pink_sunglasses"
gen_img_dir = f"{output_dir}/{output_subdir}"
for label, folder in (("Real Images:", real_img_dir), ("Generated Images:", gen_img_dir)):
    print(label, len(os.listdir(folder)))

# Both folders are resized in place to resize_images' default size before scoring.
for folder in (real_img_dir, gen_img_dir):
    resize_images(folder)

fid_score_val = fid_score_calc(real_img_dir, gen_img_dir)
print(f"FID score for |{real_img_dir}| -> |{gen_img_dir}| : {fid_score_val}")


Real Images: 6
Generated Images: 10


100%|██████████| 3/3 [00:00<00:00,  7.82it/s]
100%|██████████| 5/5 [00:00<00:00, 14.94it/s]


FID score for |dataset/pink_sunglasses| -> |Prompt_Images/Dreambooth/pink_sunglasses/sunglasses_marble| : 233.4216257075353
