In [1]:
# script to generate embeddings and perform similarity searches

import os
import re
from functools import lru_cache

import faiss
import numpy as np
import pandas as pd
import torch
from torchvision import models, transforms
from transformers import AutoImageProcessor, EfficientNetModel, ViTModel, AutoModel, CLIPProcessor, CLIPModel, Blip2Processor, Blip2Model

In [32]:
# Load the flag catalogue; its 'Country' column drives image lookup and embedding extraction below.
flags_df = pd.read_csv('national_flags.csv')
assert 'Country' in flags_df.columns, "national_flags.csv must contain a 'Country' column"

In [33]:
IMAGE_DIR = "images"


def load_local_image(country_name, image_dir=IMAGE_DIR):
    """Load the flag image for ``country_name`` from ``image_dir`` as an RGB image.

    The file looked up is ``<name>.png`` where spaces in the country name are
    replaced by underscores and square brackets are removed
    (e.g. "Foo [a]" -> "Foo_a.png").

    Parameters
    ----------
    country_name : str
        Country name as it appears in the dataset.
    image_dir : str, optional
        Directory holding the PNG files; defaults to ``IMAGE_DIR``.

    Returns
    -------
    PIL.Image.Image or None
        The decoded 3-channel image, or ``None`` (after printing a message)
        when no matching file exists.
    """
    # Sanitize the country name to match the local image file naming convention
    sanitized_country_name = country_name.replace(" ", "_").replace("[", "").replace("]", "")
    image_path = os.path.join(image_dir, f"{sanitized_country_name}.png")

    if not os.path.isfile(image_path):
        print(f"Image for {country_name} not found.")
        return None

    # `PIL.Image` is never imported in this notebook, so the former `Image.open`
    # call raised NameError on a fresh kernel (and kept a lazily-read file handle
    # open). torchvision is already a dependency: read_image decodes the whole
    # file eagerly, ImageReadMode.RGB requests 3-channel output (the counterpart
    # of the old `.convert('RGB')`), and to_pil_image returns a PIL image so the
    # HF processors and torchvision transforms receive the same type as before.
    from torchvision.io import ImageReadMode, read_image
    from torchvision.transforms.functional import to_pil_image

    return to_pil_image(read_image(image_path, mode=ImageReadMode.RGB))


In [38]:
#ViT

@lru_cache(maxsize=None)
def _load_vit(checkpoint):
    """Load the ViT image processor and encoder for ``checkpoint`` once and reuse them."""
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = ViTModel.from_pretrained(checkpoint).eval()
    return image_processor, model


def extract_features_vit(country, checkpoint="google/vit-large-patch16-224-in21k"):
    """Embed ``country``'s flag with a ViT encoder.

    The processor/model pair is cached per checkpoint, so calling this once per
    row (e.g. via ``Series.apply``) no longer reloads the weights every time.

    Parameters
    ----------
    country : str
        Country name, resolved to an image by ``load_local_image``.
    checkpoint : str, optional
        Hugging Face model id; defaults to the original ViT-Large/16 checkpoint.

    Returns
    -------
    numpy.ndarray
        Position-0 ([CLS]) vector of the last hidden layer, shape ``(1, hidden_size)``.

    Raises
    ------
    FileNotFoundError
        If no image exists for ``country``.
    """
    image_processor, model = _load_vit(checkpoint)

    img = load_local_image(country)
    if img is None:
        # Fail with a clear message instead of an opaque processor error on None.
        raise FileNotFoundError(f"No flag image found for {country!r}")
    inputs = image_processor(img, return_tensors='pt')

    with torch.no_grad():
        outputs = model(**inputs)
    # Take token 0 while keeping the batch axis -> (1, hidden_size).
    return outputs.last_hidden_state[:, 0, :].numpy()


In [46]:
#EfficientNet

@lru_cache(maxsize=None)
def _load_efficientnet(checkpoint):
    """Load the EfficientNet image processor and model for ``checkpoint`` once and reuse them."""
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = EfficientNetModel.from_pretrained(checkpoint).eval()
    return image_processor, model


def extract_features_efficientNet(country, checkpoint="google/efficientnet-b7"):
    """Embed ``country``'s flag with EfficientNet.

    Uses the last entry of the model's ``hidden_states``, averaged over its two
    trailing (spatial) axes. The processor/model pair is cached per checkpoint,
    so per-row calls do not reload the weights.

    Parameters
    ----------
    country : str
        Country name, resolved to an image by ``load_local_image``.
    checkpoint : str, optional
        Hugging Face model id; defaults to efficientnet-b7.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(1, channels)``.

    Raises
    ------
    FileNotFoundError
        If no image exists for ``country``.
    """
    image_processor, model = _load_efficientnet(checkpoint)

    img = load_local_image(country)
    if img is None:
        raise FileNotFoundError(f"No flag image found for {country!r}")
    inputs = image_processor(img, return_tensors='pt')

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # Global average pool over dims 2 and 3 of the final hidden state.
    return torch.mean(outputs.hidden_states[-1], dim=[2, 3]).numpy()
    

In [47]:
#DINO-v2

@lru_cache(maxsize=None)
def _load_dinov2(checkpoint):
    """Load the DINOv2 image processor and backbone for ``checkpoint`` once and reuse them."""
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint).eval()
    return image_processor, model


def extract_features_DINO_v2(country, checkpoint='facebook/dinov2-base'):
    """Embed ``country``'s flag with a DINOv2 backbone.

    The processor/model pair is cached per checkpoint, so per-row calls do not
    reload the weights.

    Parameters
    ----------
    country : str
        Country name, resolved to an image by ``load_local_image``.
    checkpoint : str, optional
        Hugging Face model id; defaults to dinov2-base.

    Returns
    -------
    numpy.ndarray
        Position-0 ([CLS]) vector of the last hidden layer, shape ``(1, hidden_size)``.

    Raises
    ------
    FileNotFoundError
        If no image exists for ``country``.
    """
    image_processor, model = _load_dinov2(checkpoint)

    img = load_local_image(country)
    if img is None:
        raise FileNotFoundError(f"No flag image found for {country!r}")
    inputs = image_processor(img, return_tensors='pt')

    with torch.no_grad():
        outputs = model(**inputs)
    # Take token 0 while keeping the batch axis -> (1, hidden_size).
    return outputs.last_hidden_state[:, 0, :].numpy()
    

In [56]:
#CLIP

@lru_cache(maxsize=None)
def _load_clip(checkpoint):
    """Load the CLIP processor and model for ``checkpoint`` once and reuse them."""
    processor = CLIPProcessor.from_pretrained(checkpoint)
    model = CLIPModel.from_pretrained(checkpoint).eval()
    return processor, model


def extract_features_clip(country, checkpoint="openai/clip-vit-base-patch32"):
    """Embed ``country``'s flag with CLIP's image tower (``get_image_features``).

    The processor/model pair is cached per checkpoint, so per-row calls do not
    reload the weights.

    Parameters
    ----------
    country : str
        Country name, resolved to an image by ``load_local_image``.
    checkpoint : str, optional
        Hugging Face model id; defaults to clip-vit-base-patch32.

    Returns
    -------
    numpy.ndarray
        Image embedding with a leading batch axis of size 1.

    Raises
    ------
    FileNotFoundError
        If no image exists for ``country``.
    """
    processor, model = _load_clip(checkpoint)

    img = load_local_image(country)
    if img is None:
        raise FileNotFoundError(f"No flag image found for {country!r}")
    # `padding` is a text-tokenizer option, so it is not passed for image-only input.
    inputs = processor(images=img, return_tensors='pt')

    with torch.no_grad():
        embedding = model.get_image_features(**inputs)
    return embedding.numpy()
    

In [64]:
@lru_cache(maxsize=None)
def _load_blip2(checkpoint):
    """Load the BLIP-2 processor and model for ``checkpoint`` once and reuse them.

    float16 is used only when CUDA is available (and the model is moved there);
    on CPU the model stays float32, because float16 kernels such as conv2d are
    missing from many CPU PyTorch builds.
    """
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
    else:
        device, dtype = "cpu", torch.float32
    processor = Blip2Processor.from_pretrained(checkpoint)
    model = Blip2Model.from_pretrained(checkpoint, torch_dtype=dtype).to(device).eval()
    return processor, model


def extract_features_blip(country, checkpoint="Salesforce/blip2-opt-2.7b"):
    """Embed ``country``'s flag with the BLIP-2 Q-Former.

    Returns the first query token of ``get_qformer_features(...).last_hidden_state``.
    The processor/model pair is cached per checkpoint, so the multi-GB weights
    are loaded once rather than once per country.

    Parameters
    ----------
    country : str
        Country name, resolved to an image by ``load_local_image``.
    checkpoint : str, optional
        Hugging Face model id; defaults to blip2-opt-2.7b.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(1, hidden_size)``.

    Raises
    ------
    FileNotFoundError
        If no image exists for ``country``.
    """
    processor, model = _load_blip2(checkpoint)

    img = load_local_image(country)
    if img is None:
        raise FileNotFoundError(f"No flag image found for {country!r}")
    inputs = processor(images=img, return_tensors='pt')
    # The processor emits float32 pixels on CPU; they must match the model's
    # dtype/device (previously a float16 model was fed float32 input).
    pixel_values = inputs['pixel_values'].to(device=model.device, dtype=model.dtype)

    with torch.no_grad():
        outputs = model.get_qformer_features(pixel_values=pixel_values)
    # First query token, batch axis kept; cast to float32 on CPU for downstream numpy/FAISS use.
    return outputs.last_hidden_state[:, 0, :].float().cpu().numpy()
    

In [67]:
@lru_cache(maxsize=None)
def _load_vgg16():
    """Build the ImageNet-pretrained VGG16 and its preprocessing pipeline once."""
    # `pretrained=True` is deprecated since torchvision 0.13; this enum is what it maps to.
    model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
    model.eval()  # Set the model to evaluation mode

    # Standard ImageNet preprocessing. NOTE(review): CenterCrop(224) after
    # Resize(256) trims the sides of wide (non-square) flags.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return model, preprocess


def extract_features_vgg16(country):
    """Run ``country``'s flag through the full pretrained VGG16 network.

    Returns
    -------
    numpy.ndarray
        Output of the complete model (ImageNet classifier head), shape ``(1, 1000)``.

    Raises
    ------
    FileNotFoundError
        If no image exists for ``country``.
    """
    model, preprocess = _load_vgg16()

    img = load_local_image(country)
    if img is None:
        raise FileNotFoundError(f"No flag image found for {country!r}")
    batch_t = torch.unsqueeze(preprocess(img), 0)

    with torch.no_grad():
        embedding = model(batch_t)
    # NOTE(review): these are class logits, not penultimate-layer features;
    # something like `model.classifier[:-1]` may give better similarity
    # embeddings. Left unchanged to keep the output shape stable.
    return embedding.numpy()
    

In [None]:
# Embed every flag with ViT: each 'features' cell holds the numpy array returned for that row's country.
flags_df['features'] = flags_df['Country'].map(extract_features_vit)

In [123]:
# Export embeddings to CSV.
# Writing raw numpy arrays lets pandas use numpy's str(), which wraps lines and,
# above 1000 elements, elides the middle with "..." -- silently truncating wide
# embeddings such as ViT-large's 1024 dims. Serialise every value explicitly
# ('%.9g' round-trips float32) as "[v1 v2 ...]", which clean_feature_string parses.
# The file name now matches the extractor used above (ViT) and the path the
# similarity-search cell reads from.
EMBEDDINGS_CSV = os.path.join('embeddings', 'national_flag_embeddings_vit.csv')
os.makedirs(os.path.dirname(EMBEDDINGS_CSV), exist_ok=True)
flags_df.assign(
    features=flags_df['features'].map(
        lambda vec: '[' + ' '.join(f'{v:.9g}' for v in np.ravel(vec)) + ']'
    )
).to_csv(EMBEDDINGS_CSV, index=False)

In [None]:
#Flag similarity search with FAISS: reload the exported embeddings and choose the query country.

country = "Australia"  # query flag for the search below
df = pd.read_csv('embeddings/national_flag_embeddings_vit.csv')

def clean_feature_string(feature_str):
    """Parse an embedding stored as text back into a flat float64 array.

    Accepts numpy's ``str()`` layout (``"[[0.1 0.2\\n  0.3]]"``) as well as
    comma-separated lists; brackets, commas and line breaks are treated as
    separators and the result is flattened to 1-D.

    Raises
    ------
    ValueError
        If the text contains numpy's ``...`` elision marker (the array was
        truncated when it was written), or any token is not a number.
        ``np.fromstring`` used to stop silently at such input and return a
        shorter, meaningless vector.
    """
    if '...' in feature_str:
        raise ValueError(
            "embedding text is truncated ('...'); re-export it with every value written out"
        )
    cleaned_str = re.sub(r'[\[\],]', ' ', feature_str)  # brackets/commas -> whitespace
    return np.array(cleaned_str.split(), dtype=float)

# Function to get top K similar countries using FAISS
def get_top_k_similar_countries(input_country, k=5, metric="l2", frame=None):
    """Return the ``k`` countries whose flag embeddings are nearest to ``input_country``'s.

    Builds an exact (flat) FAISS index over every row's embedding and queries
    it with the input country's vector; the country itself is excluded.

    Parameters
    ----------
    input_country : str
        Value to look up in the ``Country`` column.
    k : int, optional
        Maximum number of neighbours to return; ``k <= 0`` returns ``[]``.
    metric : {"l2", "cosine"}, optional
        ``"l2"`` (default) scores by FAISS ``IndexFlatL2`` distance -- lower
        is closer. ``"cosine"`` L2-normalises the embeddings and scores by
        inner product, i.e. cosine similarity -- higher is closer.
    frame : pandas.DataFrame, optional
        Table with ``Country`` and ``features`` (text-serialised embeddings)
        columns; defaults to the module-level ``df``.

    Returns
    -------
    list of (str, float)
        ``(country, score)`` pairs, best match first, at most ``k`` items.

    Raises
    ------
    ValueError
        If ``input_country`` is not present, ``metric`` is unknown, or an
        embedding cannot be parsed.
    """
    if metric not in ("l2", "cosine"):
        raise ValueError(f"metric must be 'l2' or 'cosine', got {metric!r}")
    data = df if frame is None else frame

    countries = data['Country'].tolist()
    try:
        input_idx = countries.index(input_country)
    except ValueError:
        # Previously this message was *returned*, which made callers that
        # unpack (country, score) pairs fail confusingly.
        raise ValueError(f"Country '{input_country}' not found in the dataset.") from None
    if k <= 0:
        return []

    # FAISS works on contiguous float32 matrices; parsed embeddings are float64.
    features = np.ascontiguousarray(
        np.stack([clean_feature_string(f) for f in data['features'].values]),
        dtype=np.float32,
    )

    if metric == "cosine":
        faiss.normalize_L2(features)  # in place; inner product of unit vectors == cosine
        index = faiss.IndexFlatIP(features.shape[1])
    else:
        index = faiss.IndexFlatL2(features.shape[1])
    index.add(features)

    query = features[input_idx:input_idx + 1]
    # One extra hit so the query itself can be dropped, capped at the index size
    # (FAISS pads missing results with index -1, which would wrap to the last row).
    scores, neighbour_idx = index.search(query, min(k + 1, len(countries)))

    results = [
        (countries[i], float(s))
        for i, s in zip(neighbour_idx[0], scores[0])
        if i >= 0 and i != input_idx
    ]
    # If ties pushed the query itself out of the hit list, k+1 neighbours remain.
    return results[:k]

# Display the query flag, then its 5 nearest flags with their scores.
top_5_countries = get_top_k_similar_countries(country, k=5)

display(load_local_image(country))
# Use a distinct loop name so the global `country` (the query) is not overwritten.
for rank, (neighbour, score) in enumerate(top_5_countries, start=1):
    print(f"{rank}. {neighbour} (score={score:.4f})")
    # Load the flag image for each country from the local folder
    display(load_local_image(neighbour))