In [None]:
import os

# Set before torch/numpy/faiss are imported so it is already in effect when
# their OpenMP runtimes load.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

import torch, torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
import faiss
from tqdm import tqdm

from analyze_image import ImageAnalyzer
from labels import labels
from dataset_creation import DatasetManager, RegexLabelExtractor

## Data prep for the pets dataset, following this [notebook](https://colab.research.google.com/github/akashmehra/blog/blob/fastbook/lessons/_notebooks/2021-07-20-pets_classifier.ipynb#scrollTo=ekNHMAUtklXS)
The dataset classes from that notebook have been moved to `dataset_creation.py`.

In [None]:
# Recreate the dataset without a train/test split.
# The directory is assembled with os.path.join so the same cell works on
# Windows (where it yields the original '.\pets\oxford-iiit-pet\images') and
# on POSIX systems.
base_img_dir = os.path.join('.', 'pets', 'oxford-iiit-pet', 'images')
paths = [path for path in sorted(os.listdir(base_img_dir)) if path.endswith('.jpg')]
# Raw string: avoids the invalid "\d" escape warning. The dot before "jpg" is
# escaped so it matches a literal "." only. The capture group is everything
# before the trailing "_<digits>.jpg" (presumably used as the label by
# RegexLabelExtractor -- confirm in dataset_creation.py).
pattern = r'(.+)_\d+\.jpg$'
regex_label_extractor = RegexLabelExtractor(pattern)
dm = DatasetManager(base_img_dir, paths, regex_label_extractor, seed=42)
data = dm.dataset

In [None]:
# Sanity check on the dataset size; the author recorded len = 7390 for this run.
len(data)

***
## Creating FAISS Index

In [None]:
# `labels` is imported from labels.py. Its contents were produced once by the
# snippet below and pasted into that file
# (`df` came from "Pet Training ResNet34.ipynb"):
#
# classes = df.to_dict()
# class_to_idx = classes['label_name']
# class_to_idx
# # flip keys and values
# labels = {v: k for k, v in class_to_idx.items()}
labels

In [None]:
# Build the ImageAnalyzer from the dataset, the label mapping and the saved
# model file, then show the model it holds as the cell output.
model_file = './artifacts/model_pets.pt'
analyze = ImageAnalyzer(data, labels, model_file)
analyze.model

In [None]:
# Create a PyTorch DataLoader over the full (un-split) dataset.
# `data` is the dataset built earlier in this notebook; the previous name
# `valid_dataset` is never defined here and only worked through leftover
# kernel state.
# shuffle=False keeps the iteration order equal to the dataset order, so the
# i-th embedding computed from this loader belongs to data[i].
pet_loader = torch.utils.data.DataLoader(data,
                                         batch_size=8,
                                         shuffle=False)

In [None]:
# Get every image's embedding and put them all in the same tensor.
# Should end with shape [7390, 512].
# Embeddings are collected in a list and concatenated once at the end;
# calling torch.cat inside the loop re-copies the whole tensor per image.
embedding_rows = []
for batch in tqdm(pet_loader):
    images = batch[0]  # first element of each batch holds the image tensors
    for image in images:
        # getEmbeddings is called one image at a time with a leading batch
        # dimension of 1, exactly as before.
        embedding_rows.append(analyze.getEmbeddings(image.unsqueeze(0)))
all_embeddings = torch.cat(embedding_rows, 0)

# Save it so we don't have to do it again. The path matches the one the
# notebook loads the embeddings from.
torch.save(all_embeddings, "./artifacts/embeddings_trained_34.pt")

In [None]:
# Read the embeddings produced by the trained ResNet back from disk.
embeddings_file = './artifacts/embeddings_trained_34.pt'
trained_embeddings = torch.load(embeddings_file)

### Examining Some of the Learned Features 

In [None]:
def display_best_images(feature_index, all_embeddings, top_k=10):
    """Display the images whose embeddings are largest at one feature index.

    Args:
        feature_index: Column of the embedding matrix to rank images by.
        all_embeddings: 2-D tensor with one embedding row per image; row ``i``
            must belong to ``data[i]``.
        top_k: How many of the highest-valued images to show. Defaults to 10,
            which is what the function always showed before.

    The selected images are drawn in a grid (5 per row) with ``analyze.imshow``.
    """
    # Never ask for more rows than exist.
    k = min(top_k, len(all_embeddings))
    # torch.topk returns the k largest values in descending order; it replaces
    # a Python sort that called .item() once per embedding.
    top_indices = torch.topk(all_embeddings[:, feature_index], k).indices.tolist()
    # Images come from `data`, the un-split dataset; the former
    # `valid_dataset` is not defined anywhere in this notebook.
    top_images = torch.stack([data[i][0] for i in top_indices])
    analyze.imshow(torchvision.utils.make_grid(top_images, nrow=5, padding=2))

In [None]:
# Images with the highest values at embedding feature 1.
display_best_images(1, trained_embeddings)

In [None]:
# Images with the highest values at embedding feature 51.
display_best_images(51, trained_embeddings)

In [None]:
# Images with the highest values at embedding feature 101.
display_best_images(101, trained_embeddings)

***
### A few tests before moving on

In [None]:
# Load a test image, run it through the analyzer's transform and compute its
# embedding; these three names are reused by the cells below.
target_img = Image.open('./test images/ragdoll.jpg')
target_transform = analyze.transform(target_img)
# NOTE(review): the embedding loop earlier passes image.unsqueeze(0) to
# getEmbeddings, while here the transform output is passed directly --
# presumably analyze.transform already adds the batch dimension; confirm.
target_embeddings = analyze.getEmbeddings(target_transform)
# Last expression of the cell, so the image itself is shown as the output.
target_img

In [None]:
# Compare the target embedding against all stored embeddings with the
# analyzer's cosine-similarity helper, then display the best matches from the
# dataset. This gives a reference result before the FAISS index is built.
top_cosine = analyze.cosine_similar_images(target_embeddings, trained_embeddings)
analyze.show_best_results(top_cosine, data)

***
### Finally creating the FAISS Index

In [None]:
# Normalize embeddings to unit L2 norm.
# Tensor.numpy() returns an array that shares memory with the tensor and
# faiss.normalize_L2 works in place, so `trained_embeddings` itself ends up
# normalized; later cells rely on that when they call .numpy() again.
embedding_matrix = trained_embeddings.numpy()
faiss.normalize_L2(embedding_matrix)

# Create ids that match those in the dataset (row i <-> dataset index i).
# They are created as int64 directly: FAISS ids are 64-bit and np.arange
# can default to a narrower integer type on some platforms.
ids = np.arange(trained_embeddings.shape[0], dtype=np.int64)

In [None]:
# https://gist.github.com/mdouze/773b2e1b42ac50f700407f3a727921e5
# Create the FAISS index.
# Use IP (inner product), which will do cosine similarity since the embeddings
# are normalized. IndexIDMap2 wraps the flat index so vectors are stored under
# explicit ids.
dim = trained_embeddings.shape[1]  # 512 features
# Size the index from the data rather than a hard-coded 512, so a different
# embedding width cannot silently mismatch.
index = faiss.IndexIDMap2(faiss.IndexFlatIP(dim))
index.add_with_ids(trained_embeddings.numpy(), ids.astype(np.int64))
index.is_trained

In [None]:
# Persist the index once, then always work with the copy stored on disk.
# If the artifact already exists it is loaded unchanged (same as before);
# if it is missing, the index built above is written first so a fresh
# Restart & Run All does not fail on a missing file.
index_file = "./artifacts/pet_faiss_index"
if not os.path.exists(index_file):
    faiss.write_index(index, index_file)
index = faiss.read_index(index_file)
# Number of vectors stored in the loaded index.
index.ntotal

In [None]:
# Sanity check: query the index with the test image's embedding and show the
# 10 nearest images.
# The query is copied into a float32, C-contiguous array (so the original
# tensor is not modified) and L2-normalized like the indexed vectors, so that
# the inner-product scores returned by FAISS are cosine similarities.
query = np.array(target_embeddings.numpy(), dtype=np.float32, order='C')
faiss.normalize_L2(query)
distance, indices = index.search(query, 10)
# FAISS pads with -1 when fewer than k results exist; skip those so they are
# not interpreted as "last item" by negative indexing.
top_images = torch.stack([data[int(j)][0] for j in indices[0] if j >= 0])
analyze.imshow(torchvision.utils.make_grid(top_images, nrow = 5, padding = 2))

In [None]:
# Ids returned by the FAISS search above (one row per query).
indices