In [1]:
# Step up one directory so the relative 'models/...' and 'data/...' paths used
# by later cells resolve against the parent folder.
# NOTE(review): this is not idempotent - re-running the cell moves up one more
# level each time; restart the kernel before running the notebook again.
%cd ..

c:\Users\HP\OneDrive - University of Moratuwa\Desktop\E-Vision-Projects\Shelf_Product_Count_Generation


In [2]:
import torch
import torch.nn as nn
from transformers import AutoModel
from torchvision import transforms
from PIL import Image
from pathlib import Path
import numpy as np
import faiss
import pickle
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# 2. DINOv2 Model with Fine-tuning Head
class FineTunedDINOv2(nn.Module):
    """DINOv2 backbone plus a projection head that emits L2-normalised embeddings.

    Args:
        model_name: Model id or path handed to ``AutoModel.from_pretrained``.
        embedding_dim: Size of the embedding produced by the projection head.
        freeze_backbone: When True, every backbone parameter gets
            ``requires_grad = False`` so only the projection head is trainable.
    """

    # Name-based widths, used only if the loaded backbone does not expose
    # ``config.hidden_size``. Order matches the original substring checks.
    _WIDTH_BY_NAME_TAG = (('base', 768), ('small', 384), ('large', 1024))
    _DEFAULT_WIDTH = 768

    def __init__(self, model_name='facebook/dinov2-base', embedding_dim=512, freeze_backbone=False):
        super().__init__()

        # Load pre-trained DINOv2
        self.dinov2 = AutoModel.from_pretrained(model_name)

        # Width of the backbone's token embeddings (input size of the head).
        dinov2_dim = self._infer_backbone_width(self.dinov2, model_name)

        # Freeze backbone if needed (for faster training)
        if freeze_backbone:
            for param in self.dinov2.parameters():
                param.requires_grad = False

        # Fine-tuning head (projection layer)
        self.projection_head = nn.Sequential(
            nn.Linear(dinov2_dim, 1024),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(1024, embedding_dim),
            nn.LayerNorm(embedding_dim)
        )

    @classmethod
    def _infer_backbone_width(cls, backbone, model_name):
        """Return the backbone hidden size, preferring its config over the model name.

        Reading ``config.hidden_size`` keeps the head correct for checkpoints whose
        name carries none of the known size tags (for example a giant variant or a
        local directory); guessing 768 there would make the first Linear layer
        mismatch the CLS token width.
        """
        hidden_size = getattr(getattr(backbone, 'config', None), 'hidden_size', None)
        if hidden_size is not None:
            return int(hidden_size)
        for tag, width in cls._WIDTH_BY_NAME_TAG:
            if tag in model_name:
                return width
        return cls._DEFAULT_WIDTH

    def forward(self, pixel_values):
        """Embed a batch of preprocessed images.

        Args:
            pixel_values: Image batch passed to the backbone as ``pixel_values``
                (presumably (batch_size, 3, 224, 224) when the notebook's default
                transform is used).

        Returns:
            Tensor of shape (batch_size, embedding_dim) with unit L2 norm per row.
        """
        outputs = self.dinov2(pixel_values=pixel_values)

        # The first token of the last hidden state is used as the image embedding.
        cls_token = outputs.last_hidden_state[:, 0, :]  # (batch_size, dinov2_dim)

        # Project to the requested embedding size, then L2-normalise each row.
        embedding = self.projection_head(cls_token)
        embedding = nn.functional.normalize(embedding, p=2, dim=1)

        return embedding

In [13]:
# 2. Load trained model and get embeddings
def load_trained_dinov2_model(checkpoint_path='models/dinov2_finetuned.pth', device=None):
    """Rebuild a FineTunedDINOv2 from a checkpoint and put it in eval mode.

    Args:
        checkpoint_path: File readable by ``torch.load``. It must contain the keys
            'model_name', 'embedding_dim' and 'model_state_dict'; 'epoch' and
            'loss' are optional and only used for the printed summary.
        device: Target torch device. Defaults to CUDA when available, else CPU.

    Returns:
        Tuple ``(model, checkpoint, device)``.

    Raises:
        KeyError: If a required checkpoint key is missing.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # SECURITY: torch.load is pickle-based; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    required_keys = ('model_name', 'embedding_dim', 'model_state_dict')
    missing_keys = [key for key in required_keys if key not in checkpoint]
    if missing_keys:
        raise KeyError(f"Checkpoint {checkpoint_path} is missing required keys: {missing_keys}")

    # Create model with the architecture recorded in the checkpoint
    model = FineTunedDINOv2(
        model_name=checkpoint['model_name'],
        embedding_dim=checkpoint['embedding_dim'],
        freeze_backbone=False
    )

    # Load weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # Set to evaluation mode
    model = model.to(device)

    print(f"Loaded model from {checkpoint_path}")
    print(f"Model: {checkpoint['model_name']}")
    print(f"Embedding dim: {checkpoint['embedding_dim']}")
    # Training metadata is informational only; its absence must not fail the load.
    if 'epoch' in checkpoint:
        print(f"Trained for {checkpoint['epoch']+1} epochs")
    if 'loss' in checkpoint:
        print(f"Best loss: {checkpoint['loss']:.4f}")

    return model, checkpoint, device

def get_image_transform_dinov2(image_size=224):
    """Build the augmentation-free preprocessing pipeline for inference.

    The image is resized to a square of side ``image_size``, converted to a
    tensor and normalised per channel with the ImageNet mean/std values. It is
    meant to mirror the preprocessing used at training time.
    """
    target_size = (image_size, image_size)
    resize_step = transforms.Resize(target_size)
    tensor_step = transforms.ToTensor()
    normalize_step = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    return transforms.Compose([resize_step, tensor_step, normalize_step])

def get_embedding_from_image_dinov2(model, image_path_or_pil, transform, device):
    """
    Get the embedding of a single image.

    Args:
        model: Callable model (e.g. a trained FineTunedDINOv2) mapping an image
            batch tensor to an embedding batch tensor
        image_path_or_pil: Path to an image (str or pathlib.Path) or a PIL Image
        transform: Callable turning an RGB PIL image into a tensor
        device: torch device the input tensor is moved to

    Returns:
        embedding: flattened numpy array; shape (embedding_dim,) for a model that
        returns (1, embedding_dim)
    """
    # Accept pathlib.Path as well as str: a Path has no .convert(), so it must
    # not fall through to the PIL branch.
    if isinstance(image_path_or_pil, (str, Path)):
        # convert() returns a new image, so the file handle can be closed here.
        with Image.open(image_path_or_pil) as opened_image:
            image = opened_image.convert('RGB')
    else:
        image = image_path_or_pil.convert('RGB')

    # Add the batch dimension the model expects and move to the target device
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Inference only: no gradients needed
    with torch.no_grad():
        embedding = model(image_tensor)

    return embedding.cpu().numpy().flatten()

def get_embeddings_batch_dinov2(model, image_paths, transform, device, batch_size=32):
    """
    Get embeddings for multiple images in mini-batches.

    Args:
        model: Callable model (e.g. a trained FineTunedDINOv2)
        image_paths: List of image paths (str or pathlib.Path) or PIL Images
        transform: Callable turning an RGB PIL image into a tensor
        device: torch device the batch tensors are moved to
        batch_size: Number of images per forward pass

    Returns:
        embeddings: numpy array of shape (num_loaded_images, embedding_dim).
        Images that fail to load or transform are reported and skipped, so the
        rows only line up with ``image_paths`` when every image loads. If no
        image could be processed an empty 1-D array is returned.
    """
    embeddings_list = []

    for start in tqdm(range(0, len(image_paths), batch_size), desc="Extracting embeddings"):
        batch_images = []

        for img_path in image_paths[start:start + batch_size]:
            try:
                # pathlib.Path is a path too; without this check it would hit
                # .convert(), fail, and be silently dropped by the except below.
                if isinstance(img_path, (str, Path)):
                    with Image.open(img_path) as opened_image:
                        image = opened_image.convert('RGB')
                else:
                    image = img_path.convert('RGB')
                batch_images.append(transform(image))
            except Exception as e:
                # Best effort: report the unreadable image and keep going.
                print(f"Error loading {img_path}: {e}")
                continue

        if not batch_images:
            continue

        # Stack into batch tensor
        batch_tensor = torch.stack(batch_images).to(device)

        # Inference only: no gradients needed
        with torch.no_grad():
            batch_embeddings = model(batch_tensor)
        embeddings_list.append(batch_embeddings.cpu().numpy())

    # Concatenate all batches
    if embeddings_list:
        return np.vstack(embeddings_list)
    return np.array([])

In [14]:
# 3. Build FAISS index with all reference images
def build_reference_index_dinov2(model, reference_dir='data/reference_images', 
                                 transform=None, device=None, batch_size=32):
    """
    Build a FAISS inner-product index from all reference images.

    ``reference_dir`` is expected to contain one sub-folder per product; the
    folder name is used as the product id. Files with a .jpg, .jpeg or .png
    extension (any letter case) inside each folder are indexed.

    Args:
        model: Model passed through to get_embeddings_batch_dinov2
        reference_dir: Directory holding the per-product folders
        transform: Image transform; defaults to get_image_transform_dinov2()
        device: torch device; defaults to CUDA when available, else CPU
        batch_size: Batch size used for embedding extraction

    Returns:
        index: FAISS IndexFlatIP holding one vector per reference image
        product_ids: List of product IDs, aligned with the index rows
        image_paths: List of image paths (str), aligned with the index rows

    Raises:
        ValueError: If no reference image is found, or if the number of
            embeddings differs from the number of images (some images failed
            to load), because index rows would then map to the wrong product.
    """
    if transform is None:
        transform = get_image_transform_dinov2()
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    reference_path = Path(reference_dir)
    all_image_paths = []
    product_ids = []

    def _folder_sort_key(folder):
        # Numeric names sort by value; other names keep key 0 as before. The
        # name is a tie-breaker so the order does not depend on iterdir().
        name = folder.name
        return (int(name) if name.isdecimal() else 0, name)

    product_folders = sorted((d for d in reference_path.iterdir() if d.is_dir()),
                             key=_folder_sort_key)

    # Grouped by extension in this order, sorted within each group.
    image_suffixes = ('.jpg', '.jpeg', '.png')

    for product_folder in product_folders:
        product_id = product_folder.name
        folder_files = [p for p in product_folder.iterdir() if p.is_file()]

        for suffix in image_suffixes:
            # Compare lower-cased suffixes so e.g. 'IMG.JPG' is also picked up
            # on case-sensitive filesystems.
            for image_path in sorted(p for p in folder_files if p.suffix.lower() == suffix):
                all_image_paths.append(str(image_path))
                product_ids.append(product_id)

    if not all_image_paths:
        raise ValueError(f"No reference images (.jpg/.jpeg/.png) found under {reference_path}")

    print(f"Processing {len(all_image_paths)} reference images...")

    # Get embeddings for all images
    embeddings = get_embeddings_batch_dinov2(model, all_image_paths, transform, device, batch_size)

    # Row i of the index must belong to product_ids[i]; a skipped image would
    # shift every later row, so refuse to build a misaligned index.
    if embeddings.ndim != 2 or embeddings.shape[0] != len(all_image_paths):
        embedded_count = embeddings.shape[0] if embeddings.ndim == 2 else 0
        raise ValueError(
            f"Expected {len(all_image_paths)} embeddings but got {embedded_count}; "
            "some reference images could not be embedded, so index rows would "
            "not line up with product_ids."
        )

    # Inner product equals cosine similarity when the embeddings are L2-normalised.
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)

    # Convert to float32 for FAISS
    embeddings_f32 = embeddings.astype('float32')
    index.add(embeddings_f32)

    print(f"Built FAISS index with {index.ntotal} vectors")

    return index, product_ids, all_image_paths

In [15]:
# 4. Find similar products
def find_similar_products_dinov2(query_embedding, index, product_ids, image_paths=None, 
                                 top_k=5, threshold=None):
    """
    Find the top-k most similar reference images by inner product.

    Args:
        query_embedding: numpy array of shape (embedding_dim,)
        index: FAISS index (IndexFlatIP)
        product_ids: List of product IDs, aligned with the index rows
        image_paths: Optional list of image paths, aligned with the index rows
        top_k: Maximum number of neighbours to request
        threshold: Optional minimum similarity; lower-scoring hits are dropped

    Returns:
        results: List of dicts with product_id, similarity, similarity_percent,
        rank (1-based position in the raw search result) and, when image_paths
        is given, image_path. May hold fewer than top_k entries.
    """
    # Reshape for FAISS (needs batch dimension)
    query_embedding = query_embedding.reshape(1, -1).astype('float32')

    # Search (higher score = more similar)
    similarities, indices = index.search(query_embedding, top_k)

    results = []
    for rank, (similarity, idx) in enumerate(zip(similarities[0], indices[0]), start=1):
        # FAISS pads the result with index -1 when fewer than top_k vectors are
        # available; product_ids[-1] would silently report the last product.
        if idx < 0:
            continue
        if threshold is not None and similarity < threshold:
            continue

        row = int(idx)
        result = {
            'product_id': product_ids[row],
            'similarity': float(similarity),
            'similarity_percent': f"{similarity*100:.2f}%",
            'rank': rank
        }
        if image_paths:
            result['image_path'] = image_paths[row]
        results.append(result)

    return results

In [25]:
# Restore the fine-tuned model from its checkpoint, then build the inference transform.
model, checkpoint, device = load_trained_dinov2_model(
    checkpoint_path='models/dinov2_finetuned.pth',
)
transform = get_image_transform_dinov2()

Loaded model from models/dinov2_finetuned.pth
Model: facebook/dinov2-base
Embedding dim: 512
Trained for 5 epochs
Best loss: 0.2295


In [26]:
# Sanity check: embed one test crop and inspect the resulting vector.
image_path = 'data/test_images/cropped_image_3.jpg'
embedding = get_embedding_from_image_dinov2(
    model=model,
    image_path_or_pil=image_path,
    transform=transform,
    device=device,
)
print(f"Embedding shape: {embedding.shape}")
# The model L2-normalises its output, so the norm is expected to be ~1.0.
print(f"Embedding norm: {np.linalg.norm(embedding):.4f}")

Embedding shape: (512,)
Embedding norm: 1.0000


In [27]:
# Embed every reference image (default reference directory) and index the vectors with FAISS.
index, product_ids, image_paths = build_reference_index_dinov2(
    model=model,
    batch_size=16,
)


Processing 166 reference images...


Extracting embeddings: 100%|██████████| 11/11 [00:36<00:00,  3.34s/it]

Built FAISS index with 166 vectors





In [29]:
# Query the index with a test crop.
# Forward slashes resolve on Windows, Linux and macOS alike; a backslash-separated
# path is only split into directories on Windows.
query_image_path = 'data/test_images/cropped_image_3.jpg'
query_embedding = get_embedding_from_image_dinov2(model, query_image_path, transform, device)

# Keep at most 5 neighbours whose similarity is at least 0.7.
matches = find_similar_products_dinov2(query_embedding, index, product_ids, image_paths, 
                                      top_k=5, threshold=0.7)

# Display results.
# NOTE(review): the recorded run of this cell shows ~100% similarity for several
# different product ids, so the 0.7 threshold is not separating products. Check
# whether the fine-tuned embeddings have collapsed before relying on these matches.
for match in matches:
    print(f"Product {match['product_id']}: Similarity={match['similarity_percent']}")

Product 15832: Similarity=100.00%
Product 675: Similarity=100.00%
Product 79: Similarity=100.00%
Product 15832: Similarity=100.00%
Product 1003: Similarity=99.99%
