<a href="https://colab.research.google.com/github/sanjaysamuels/ImageSimilarityFinder-SiameseNetwork/blob/master/ImageSimilartySearch_SiameseNetwork.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets

# Image Similarity Search

To find snacks similar to a given snack image, I use a Siamese network trained with a triplet loss. The network was trained on the 'Matthijs/snacks' dataset from Hugging Face. A Siamese network is a type of neural network architecture that learns to distinguish between similar and dissimilar items, which makes it well suited to tasks such as image similarity. The triplet loss function encourages the network to reduce the distance between an anchor image (the input snack image) and a positive image (a similar snack) while increasing the distance between the anchor image and a negative image (a dissimilar snack). The model achieved an accuracy of at least 0.8, indicating that it reliably identifies similar snacks. In addition, the image dataset has been stored persistently in an appropriate storage system for convenient future use.



In [None]:
from datasets import load_dataset

# Fetch the "Matthijs/snacks" dataset; later cells index the result by split
# name ('train' and 'test').
dataset = load_dataset(path="Matthijs/snacks")

In [None]:
# Mount the user's Google Drive inside the Colab runtime. The trained model is
# saved there, and the query image for the similarity search is read from it.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
import numpy as np
from datasets import load_dataset

import torch.nn.functional as F

import torch.nn.functional as F


class SiameseNetwork(nn.Module):
    """Shared-weight CNN that maps a batch of images to 128-value embeddings.

    The feature extractor (``self.cnn``) is two Conv2d -> ReLU -> MaxPool2d
    stages. Its output is global-average-pooled, flattened and passed through
    two linear layers (``self.fc``).

    - ``forward_one`` embeds a single batch of images.
    - ``forward`` embeds two batches with the same weights and returns both
      embeddings, so a triplet can be processed as an anchor/positive call
      followed by an anchor/negative call.
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor: 3 -> 64 -> 128 channels, 5x5 kernels
        conv_stages = [
            nn.Conv2d(3, 64, 5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        ]
        self.cnn = nn.Sequential(*conv_stages)

        # Embedding head: global average pooling, then 128 -> 256 -> 128
        embedding_head = [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
        ]
        self.fc = nn.Sequential(*embedding_head)

    def forward_one(self, x):
        """Return the embedding of one batch of images."""
        return self.fc(self.cnn(x))

    def forward(self, input1, input2):
        """Embed both inputs with the shared weights and return the pair."""
        first = self.forward_one(input1)
        second = self.forward_one(input2)
        return first, second



class TripletLoss(nn.Module):
    """Hinge-style triplet loss on Euclidean embedding distances.

    For each triplet the loss is ``max(margin + d(a, p) - d(a, n), 0)``, where
    ``d`` is the L2 norm taken along dim 1. Minimising it pulls anchor-positive
    embeddings together and pushes anchor-negative embeddings apart. The
    per-triplet losses are averaged over the batch.

    Reference: https://towardsdatascience.com/siamese-network-triplet-loss-b4ca82c1aec8#:~:text=you%20can%20train%20the%20network,negative%20image%20must%20be%20high.
    """

    def __init__(self, margin=0.9):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the mean triplet loss over the batch."""
        positive_gap = (anchor - positive).norm(p=2, dim=1)
        negative_gap = (anchor - negative).norm(p=2, dim=1)
        # Triplets already separated by more than the margin contribute zero
        hinge = (self.margin + positive_gap - negative_gap).clamp(min=0.0)
        return hinge.mean()

class CustomDataset(Dataset):
    """PyTorch-compatible dataset built on top of an indexable input dataset.

    Every record of the input dataset is read through its ``"image"`` and
    ``"label"`` entries. Images are always brought to three channels, so that
    grayscale or RGBA inputs cannot break batching or the 3-channel first
    convolution of the network.

    Args:
        dataset: indexable collection of records; must support ``len()``.
        transform: optional callable applied to the image (as a PIL image).
            When omitted, the image is returned as a NumPy array.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the record at position ``idx``."""
        item = self.dataset[idx]
        image = item["image"]
        if hasattr(image, "convert"):
            # PIL-style image: let the imaging library handle every colour
            # mode (palette, CMYK, grayscale, alpha, ...).
            image = image.convert("RGB")
        image = self._as_three_channels(np.array(image))
        label = item["label"]
        if not self.transform:
            return image, label
        return self.transform(Image.fromarray(image)), label

    @staticmethod
    def _as_three_channels(pixels):
        """Return ``pixels`` as an (H, W, 3) array when it is 1- or 4-channel."""
        if pixels.ndim == 2:
            # Grayscale without a channel axis: replicate into three channels
            return np.stack([pixels] * 3, axis=-1)
        if pixels.ndim == 3 and pixels.shape[2] == 1:
            return np.repeat(pixels, 3, axis=2)
        if pixels.ndim == 3 and pixels.shape[2] == 4:
            # Drop the fourth (alpha) channel
            return np.ascontiguousarray(pixels[..., :3])
        return pixels

# Every image is resized to 256x256 and converted to a tensor, so that samples
# share one size and can be stacked into batches
transformImage = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Training split wrapped for PyTorch and served in shuffled batches of 32
custom_dataset = CustomDataset(dataset['train'], transform=transformImage)
data_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True)

# Number of passes over the training data and the device to train on
num_epochs = 7
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Siamese network (placed on the training device), its loss and its optimizer
siamese_net = SiameseNetwork()
siamese_net.to(device)
triplet_loss = TripletLoss()
optimizer = optim.Adam(siamese_net.parameters(), lr=1e-4)

def _sample_triplet_indices(labels):
    """Pick label-consistent (anchor, positive, negative) indices in a batch.

    An anchor is any sample that has at least one other sample with the same
    label (a positive candidate) and at least one sample with a different
    label (a negative candidate) in the batch. For every anchor, one positive
    and one negative are drawn uniformly at random from its candidates.

    Args:
        labels: 1-D tensor holding the class label of every sample in a batch.

    Returns:
        A tuple ``(anchor_indices, positive_indices, negative_indices)`` of
        equally long index tensors, or ``None`` when the batch contains no
        usable anchor.
    """
    same_label = labels.unsqueeze(0) == labels.unsqueeze(1)
    not_self = ~torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    positive_mask = same_label & not_self
    negative_mask = ~same_label

    usable = positive_mask.any(dim=1) & negative_mask.any(dim=1)
    if not usable.any():
        return None

    anchor_indices = usable.nonzero(as_tuple=True)[0]
    # Each mask row is used as sampling weights: 1 for a candidate, 0 otherwise
    positive_indices = torch.multinomial(
        positive_mask[anchor_indices].float(), 1).squeeze(1)
    negative_indices = torch.multinomial(
        negative_mask[anchor_indices].float(), 1).squeeze(1)
    return anchor_indices, positive_indices, negative_indices


# Training epochs for the Siamese network.
# Training is done in batches of 32. Within each batch, triplets are formed
# from the labels: the positive shares the anchor's label, the negative has a
# different label. Without this the triplet loss carries no class signal.
for epoch in range(num_epochs):
    total_loss = 0.0
    batches_used = 0
    siamese_net.train()
    for inputs, labels in tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs = inputs.to(device)
        labels = labels.to(device)

        triplet_indices = _sample_triplet_indices(labels)
        if triplet_indices is None:
            # No sample in this batch has both a positive and a negative
            continue
        anchor_indices, positive_indices, negative_indices = triplet_indices

        # Forward pass: embed the batch once and gather the triplet members,
        # instead of sending the anchor images through the network twice
        embeddings = siamese_net.forward_one(inputs)
        anchor_outputs = embeddings[anchor_indices]
        positive_outputs = embeddings[positive_indices]
        negative_outputs = embeddings[negative_indices]

        # Calculate triplet loss
        loss = triplet_loss(anchor_outputs, positive_outputs, negative_outputs)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        batches_used += 1

    # Average over the batches that actually contributed a loss value
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/max(batches_used, 1)}")

# Persist the trained weights (state_dict only) to the mounted Google Drive
torch.save(
    obj=siamese_net.state_dict(),
    f='/content/drive/My Drive/Colab Notebooks/siamese_net_snacks.pth',
)

In [None]:
import torch
from torchvision import transforms
from PIL import Image
import numpy as np

# Load the Siamese network trained in the previous step. The path is the one
# the training cell saves to; the weights are mapped onto `device` and the
# model is moved there so it matches the inputs it is given below.
load_siamese_net = SiameseNetwork()
load_siamese_net.load_state_dict(
    torch.load(
        '/content/drive/My Drive/Colab Notebooks/siamese_net_snacks.pth',
        map_location=device,
    )
)
load_siamese_net.to(device)
load_siamese_net.eval()


def get_embedding(image_path):
    """Return the embedding of the image stored at ``image_path``.

    The image is converted to RGB, resized to 256x256, turned into a tensor
    with a leading batch dimension and embedded with ``load_siamese_net`` on
    the device the model lives on. Gradients are not tracked.

    Args:
        image_path: path of the image file to embed.

    Returns:
        The output of ``load_siamese_net.forward_one`` for the single image.
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    model_device = next(load_siamese_net.parameters()).device
    # The context manager closes the file handle once the tensor is built
    with Image.open(image_path) as image:
        # RGB conversion keeps grayscale/RGBA files compatible with the
        # 3-channel first convolution; unsqueeze adds the batch dimension
        batch = transform(image.convert("RGB")).unsqueeze(0).to(model_device)
    with torch.no_grad():
        embedding = load_siamese_net.forward_one(batch)
    return embedding


input_image_embedding = get_embedding("/content/drive/My Drive/Colab Notebooks/oranges.jpeg")

# Embed the whole training dataset. The embeddings are compared with the query
# embedding via cosine_similarity in the next cell. Images are processed in
# order (shuffle=False), so position i in train_embeddings belongs to
# custom_dataset[i]; every entry keeps its own batch dimension of size 1.
train_embeddings = []
embedding_loader = DataLoader(custom_dataset, batch_size=64, shuffle=False)
with torch.no_grad():
    for images, _ in embedding_loader:
        batch_embeddings = load_siamese_net.forward_one(images.to(device))
        train_embeddings.extend(batch_embeddings.split(1))

def cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two embeddings as a Python float.

    The similarity is computed along dim 1 and converted with ``.item()``, so
    the two inputs must produce exactly one similarity value (for example two
    embeddings with a batch dimension of size 1).
    """
    return torch.nn.functional.cosine_similarity(embedding1, embedding2, dim=1).item()

In [None]:
import matplotlib.pyplot as plt

# Only candidates whose cosine similarity reaches this value are returned
similarity_threshold = 0.9

# Score every training embedding against the query embedding, keep the ones at
# or above the threshold, and order them from most to least similar
scored_candidates = (
    (position, cosine_similarity(input_image_embedding, candidate))
    for position, candidate in enumerate(train_embeddings)
)
similar_images = sorted(
    (pair for pair in scored_candidates if pair[1] >= similarity_threshold),
    key=lambda pair: pair[1],
    reverse=True,
)

# Show at most the five best matches, one figure per image
num_similar_images_to_display = min(5, len(similar_images))
for idx, similarity in similar_images[:num_similar_images_to_display]:
    image, _ = custom_dataset[idx]

    plt.figure()
    # Move dim 0 of the image tensor to the end before handing it to imshow
    plt.imshow(image.permute(1, 2, 0).numpy())
    plt.title(f"Similarity: {similarity:.2f}")
    plt.axis('off')

plt.show()

'''
Un-comment code below to print the similary score of the similar_images array
'''
# Print similar images and their similarity scores
# for idx, similarity in similar_images:
#     image, label = custom_dataset[idx]
#     print(f"Similarity: {similarity:.2f}")

In [None]:
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


# Load the test dataset to test the trained model's accuracy. It uses the same
# preprocessing as the training images (`transformImage`).
test_dataset = CustomDataset(dataset['test'], transform=transformImage)

# Compute embeddings for test images and remember each image's label, which
# is the ground truth for "similar" vs. "dissimilar" pairs
siamese_net.eval()
test_embeddings = []
test_labels = []
for idx in range(len(test_dataset)):
    image, label = test_dataset[idx]
    with torch.no_grad():
        embedding = siamese_net.forward_one(image.unsqueeze(0).to(device))
    test_embeddings.append(embedding)
    test_labels.append(int(label))

validation_similarity_threshold = 0.8

# Candidate masks: entry [i, j] says whether image j may serve as a positive
# (same label, not the image itself) or as a negative (different label) for i
labels_tensor = torch.tensor(test_labels)
same_label = labels_tensor.unsqueeze(0) == labels_tensor.unsqueeze(1)
not_self = ~torch.eye(len(test_labels), dtype=torch.bool)
positive_mask = same_label & not_self
negative_mask = ~same_label

# Fixed seed so the sampled evaluation pairs are the same on every run
pair_rng = torch.Generator().manual_seed(0)

# Validation
true_positives = 0
true_negatives = 0
false_positives = 0
false_negatives = 0

for anchor_idx, anchor_embedding in enumerate(test_embeddings):
    # Positive pair: the anchor and a random other image with the same label
    if positive_mask[anchor_idx].any():
        positive_idx = torch.multinomial(
            positive_mask[anchor_idx].float(), 1, generator=pair_rng).item()
        similarity = cosine_similarity(anchor_embedding, test_embeddings[positive_idx])
        if similarity >= validation_similarity_threshold:
            # Similar pair predicted as similar (true positive)
            true_positives += 1
        else:
            # Similar pair predicted as dissimilar (false negative)
            false_negatives += 1

    # Negative pair: the anchor and a random image with a different label
    if negative_mask[anchor_idx].any():
        negative_idx = torch.multinomial(
            negative_mask[anchor_idx].float(), 1, generator=pair_rng).item()
        similarity = cosine_similarity(anchor_embedding, test_embeddings[negative_idx])
        if similarity < validation_similarity_threshold:
            # Dissimilar pair predicted as dissimilar (true negative)
            true_negatives += 1
        else:
            # Dissimilar pair predicted as similar (false positive)
            false_positives += 1

# Calculate accuracy, precision, recall, and F1-score over the evaluated pairs;
# every ratio falls back to 0.0 when its denominator is zero
total_pairs = true_positives + true_negatives + false_positives + false_negatives
accuracy = _safe_ratio(true_positives + true_negatives, total_pairs)
precision = _safe_ratio(true_positives, true_positives + false_positives)
recall = _safe_ratio(true_positives, true_positives + false_negatives)
f1_score = _safe_ratio(2 * (precision * recall), precision + recall)

print(f"Accuracy: {accuracy:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1 Score: {f1_score:.2f}")
