Load all 10,000 images, encode them with the pretrained autoencoder, save the encodings, and retrieve the most similar images for a query image.

In [None]:
# Silence TensorFlow's C++ log output (the goal here is to avoid warnings).
# It is set in this cell, ahead of the tensorflow imports in the next cell.
import os
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '2'})

In [None]:
# imports
import numpy as np
from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow import convert_to_tensor
from sklearn.metrics.pairwise import cosine_similarity
import IPython.display as display
from PIL import Image
import matplotlib.pyplot as plt
import math
import time

In [None]:
# get all image file names from the art directory.
# Hidden entries (e.g. a Jupyter `.ipynb_checkpoints` folder) are skipped:
# they have no numeric "<index>_" prefix and would make the numeric sort
# by file-name prefix raise ValueError.
img_path = [
    name
    for name in os.listdir('/home/jupyter/images/all/art')
    if not name.startswith('.')
]

In [None]:
# sort key: numeric index found before the first '_' of a file name
def sort_by_prefix(str):
    """Return the integer prefix of a file name, e.g. '42_title.jpg' -> 42.

    NOTE: the parameter name shadows the builtin ``str``; it is kept as-is
    so existing callers are unaffected.
    """
    prefix, _sep, _rest = str.partition('_')
    return int(prefix)

In [None]:
# order the file names numerically by their index prefix, mutating img_path in place
img_path[:] = sorted(img_path, key=sort_by_prefix)

In [None]:
# initialize variables to batch load the images (10000 expected)
img_tensors = []  # one entry per loaded image, in img_path order
batch_size = 100  # images loaded between progress/timing reports
total_size = len(img_path)
# number of batches, rounded up so a final partial batch is still loaded
n_runs = math.ceil(total_size / batch_size)

In [None]:
# batch load images into a tensor list
# Every image is converted to 3-channel RGB and resized to 224x224, so each
# entry has shape (224, 224, 3) and can be stacked into one array matching
# the autoencoder's Input(shape=(224, 224, 3)).
# NOTE(review): pixel values are left in 0-255; the decoder ends in a sigmoid,
# so confirm whether the pretrained weights expect inputs scaled to [0, 1].
for i in range(n_runs):
    start_ix = i * batch_size
    end_ix = min(start_ix + batch_size, total_size)
    start_time = time.time()
    for image_filename in img_path[start_ix:end_ix]:
        path = os.path.join('/home/jupyter/images/all/art', image_filename)
        # `with` closes the file handle; convert()/resize() return new images
        # whose pixel data no longer depends on the open file.
        with Image.open(path) as src:
            img = src.convert('RGB').resize((224, 224))
        img_tensors.append(convert_to_tensor(np.asarray(img)))
    iterate_time = time.time()
    print(f'{end_ix} images loaded in {iterate_time - start_time} s')

In [None]:
# stack (at most) the first 10000 loaded images into a single ndarray;
# expected shape is (10000, 224, 224, 3)
x = np.asarray(img_tensors[:10000])

In [None]:
# sanity check: expect (10000, 224, 224, 3)
np.shape(x)

In [None]:
# Encoder: 224x224x3 image -> 56x56x8 feature map
# (each 'same'-padded 2x2 max-pool halves height and width: 224 -> 112 -> 56).
# The intermediate tensor is named `h`, not `x`, so this cell no longer
# overwrites the image array `x` that is later passed to encoder.predict()
# and indexed for display.
input_img = Input(shape=(224, 224, 3))
h = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
h = MaxPooling2D((2, 2), padding='same')(h)
h = Conv2D(8, (3, 3), activation='relu', padding='same')(h)
encoded = MaxPooling2D((2, 2), padding='same')(h)

# Decoder: mirrors the encoder, 56x56x8 -> 224x224x3 (two 2x upsamplings);
# the sigmoid output keeps each reconstructed value in (0, 1).
h = Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)
h = UpSampling2D((2, 2))(h)
h = Conv2D(16, (3, 3), activation='relu', padding='same')(h)
h = UpSampling2D((2, 2))(h)
decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(h)

In [None]:
# Autoencoder: the full encoder -> decoder graph, compiled with Adam and
# binary cross-entropy (paired with the sigmoid output layer).
autoencoder = Model(inputs=input_img, outputs=decoded)
autoencoder.compile(loss='binary_crossentropy', optimizer='adam')

In [None]:
# restore the pretrained autoencoder weights from the saved checkpoint file
autoencoder.load_weights(filepath='weights0039.hdf5')

In [None]:
# encoder-only model: shares the (now pretrained) layers up to the bottleneck,
# so it maps an image to its compressed representation
encoder = Model(inputs=input_img, outputs=encoded)

In [None]:
# run every loaded image through the encoder to get its representation
encoded_images = encoder.predict(x=x)

In [None]:
# persist the encodings so later sessions can skip the encoding step
np.save(file='encoded_10000.npy', arr=encoded_images)

In [None]:
# read the saved encodings back from disk
encoded_images_loaded = np.load(file='encoded_10000.npy')

In [None]:
# compare this shape with encoded_images.shape from before saving
np.shape(encoded_images_loaded)

In [None]:
# similarity search over the saved encodings
def get_top_5_similar(input_image, k=5):
    """Return dataset indices of the images most similar to ``input_image``.

    The query is encoded with the global ``encoder`` and compared by cosine
    similarity against every entry of the global ``encoded_images_loaded``
    (each encoding flattened to one vector).

    Args:
        input_image: a batch holding a single image, e.g.
            ``np.expand_dims(img, axis=0)``; presumably shaped
            (1, 224, 224, 3) to match the encoder input — confirm at call site.
        k: number of indices to return (default 5, as the name says).

    Returns:
        A 1-D integer array of up to ``k`` indices into
        ``encoded_images_loaded``, ordered from most to least similar.
        If the query image is itself part of the dataset, it will normally
        be the first result.
    """
    if k <= 0:
        # guard: a slice of [-0:] would otherwise return every index
        return np.array([], dtype=np.intp)

    input_representation = encoder.predict(input_image)

    # Cosine similarity between the flattened query encoding and every
    # flattened dataset encoding -> array of shape (1, n_images).
    similarities = cosine_similarity(
        input_representation.reshape(1, -1),
        encoded_images_loaded.reshape(len(encoded_images_loaded), -1),
    )

    # argsort is ascending, so the best matches sit at the end:
    # take the last k and reverse them into descending order.
    return np.argsort(similarities[0])[-k:][::-1]

In [None]:
# Example: find the 5 images most similar to the last loaded image, x[-1]
# (x[9999] when all 10000 images were loaded; x[-1] also works with fewer).
# That image is itself in the encoded dataset, so it will normally come back
# as its own top match.
input_image = np.expand_dims(x[-1], axis=0)  # add a batch axis for encoder.predict
similar_indices = get_top_5_similar(input_image)
print(f"Indices of top 5 similar images: {similar_indices}")

In [None]:
# Display the input image and its most similar images side by side
import matplotlib.pyplot as plt

# One panel for the query plus one per match, instead of a hard-coded 6,
# so the grid always matches however many indices were returned.
n_panels = 1 + len(similar_indices)
# squeeze=False always returns a 2-D grid of Axes (even for a single panel);
# keep its only row. 2.5 inches per panel reproduces the original 15-inch width for 6.
fig, ax = plt.subplots(1, n_panels, figsize=(2.5 * n_panels, 5), squeeze=False)
ax = ax[0]

ax[0].imshow(input_image[0])
ax[0].set_title("Input Image")
ax[0].axis('off')

for i, index in enumerate(similar_indices, 1):
    ax[i].imshow(x[index])
    ax[i].set_title(f"Similar {i}")
    ax[i].axis('off')

plt.show()