In [None]:
#Install Dependencies

!pip install torch torchvision pinecone pillow cloudinary

import json
import os
import time

import cloudinary
import cloudinary.uploader
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from google.colab import drive
from google.colab import userdata
from PIL import Image
from pinecone import Pinecone, ServerlessSpec
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


# Root of the AFHQ training split on the mounted Google Drive.
DATASET_PATH = '/content/drive/MyDrive/datasets/afhq/train'

# Project tag: used as the Pinecone namespace and stored in every vector's metadata.
PROJECT_NAME = 'visual-search'

# Vector width of the shared Pinecone index; the ResNet-50 features used here are 2048-d.
COMMON_DIMENSION = 2048

In [None]:
# Attach Google Drive so files under /content/drive (including DATASET_PATH) are readable.
drive.mount('/content/drive')

# Run models on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Dataset class for AFHQ
class AFHQDataset(Dataset):
    """Image-path dataset over the AFHQ folder layout ``<root_dir>/<category>/<image>``.

    Each item is ``(image, label, image_path)``. ``label`` is the category folder
    name as a string (e.g. ``'wild'``), not an integer id.

    Args:
        root_dir: Directory containing one sub-folder per category.
        limit: Maximum number of images collected in total across all categories.
            Values <= 0 collect nothing.
        transform: Optional callable applied to the RGB PIL image in ``__getitem__``.
        categories: Category sub-folders to scan, in order. Defaults to ``('wild',)``
            to keep the original behaviour; pass ``('cat', 'dog', 'wild')`` for all.
    """

    # Matched case-insensitively so files like "X.JPG" are not silently skipped.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, limit=500, transform=None, categories=('wild',)):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        # Integer id per category; not used by this class itself.
        self.label_map = {'cat': 0, 'dog': 1, 'wild': 2}

        remaining = max(limit, 0)
        for category in categories:
            if remaining == 0:
                break
            category_dir = os.path.join(root_dir, category)
            if not os.path.exists(category_dir):
                print(f"Category {category} not found")
                continue
            # Sort so the `limit` subset is the same on every run; os.listdir order is arbitrary.
            for image_name in sorted(os.listdir(category_dir)):
                if remaining == 0:
                    break
                if image_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.images.append(os.path.join(category_dir, image_name))
                    self.labels.append(category)
                    remaining -= 1

        print(f"Loaded {len(self.images)} images")
        names, counts = np.unique(self.labels, return_counts=True)
        # tolist() gives plain str/int so the summary prints cleanly.
        print(f"Categories: {dict(zip(names.tolist(), counts.tolist()))}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Load image ``index`` as RGB, apply ``transform`` if set, return ``(image, label, path)``."""
        image_path = self.images[index]
        # `with` closes the file handle; convert() returns an independent in-memory copy.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self.labels[index], image_path



In [None]:
# Feature Extraction Model

class FeatureExtractor:
    """Turns an image file into an L2-normalised embedding vector.

    Backbones:
      * ResNet-50 (default): ImageNet-pretrained with the final classification
        layer removed; ``extract`` returns a 2048-d vector.
      * CLIP ViT-B/32 (``use_clip=True``): image embedding from
        ``CLIPModel.get_image_features``; ``extract`` returns a 512-d vector.

    Attributes:
        model: The backbone, moved to the module-level ``device`` and put in eval mode.
        processor: CLIP preprocessing (only set when ``use_clip`` is True).
        dim: Length of the vectors returned by ``extract``.
        transform: ImageNet preprocessing used on the ResNet path.
        use_clip: Which backbone is active.

    Args:
        model_name: Informational only; the backbone is selected by ``use_clip``.
        use_clip: Use CLIP instead of ResNet-50.
    """

    def __init__(self, model_name='resnet50', use_clip=False):
        self.use_clip = use_clip

        if use_clip:
            # CLIP gives embeddings with better semantic (text-aligned) structure.
            CLIPModel, CLIPProcessor = self._import_clip()
            self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            self.dim = 512  # CLIP ViT-B/32 image-embedding size
        else:
            # `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is the
            # checkpoint that `pretrained=True` selected.
            backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
            # Drop the final classification layer to keep the pooled feature vector.
            self.model = nn.Sequential(*list(backbone.children())[:-1])
            self.dim = 2048  # ResNet-50 pooled feature size

        self.model.to(device)
        self.model.eval()

        # ImageNet-style preprocessing for the ResNet path.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _import_clip():
        """Import the CLIP classes from `transformers`, installing the package into this kernel if missing."""
        try:
            from transformers import CLIPModel, CLIPProcessor
        except ImportError:
            import subprocess
            import sys
            # Same effect as `%pip install transformers`, but valid outside IPython and
            # guaranteed to target the running interpreter's environment.
            subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers"])
            from transformers import CLIPModel, CLIPProcessor
        return CLIPModel, CLIPProcessor

    @staticmethod
    def _l2_normalize(vector):
        """Return `vector` scaled to unit L2 norm.

        Raises:
            ValueError: if the norm is zero (dividing would produce NaNs that a
                vector store would later reject).
        """
        norm = np.linalg.norm(vector)
        if norm == 0:
            raise ValueError("cannot L2-normalize an all-zero feature vector")
        return vector / norm

    def extract(self, image_path):
        """Return the L2-normalised 1-D numpy feature vector (length ``self.dim``) for one image file."""
        # `with` closes the file handle; convert() returns an independent copy.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        with torch.no_grad():
            if self.use_clip:
                inputs = self.processor(images=image, return_tensors="pt").to(device)
                features = self.model.get_image_features(**inputs)
            else:
                image_tensor = self.transform(image).unsqueeze(0).to(device)
                features = self.model(image_tensor)

        # Unit length, so cosine similarity equals the dot product.
        return self._l2_normalize(features.cpu().numpy().flatten())

In [None]:
# Extract Feature from AFHQ Dataset
# For each image: compute its embedding, upload the file to Cloudinary, and collect a
# Pinecone-ready record {'id', 'values', 'metadata'} in `features_data`.

# Setup Cloudinary (credentials are read from Colab secrets, never hard-coded)
cloudinary.config(
    cloud_name=userdata.get('CLOUDINARY_CLOUD_NAME'),
    api_key=userdata.get('CLOUDINARY_API_KEY'),
    api_secret=userdata.get('CLOUDINARY_API_SECRET'),
    secure=True,
)

# Initialize feature extractor
extractor = FeatureExtractor(model_name='resnet50', use_clip=False)
print(f"Feature dimension: {extractor.dim}")
# Vectors must match the index width used when the Pinecone index is created.
assert extractor.dim == COMMON_DIMENSION, (
    f"Extractor dim {extractor.dim} != COMMON_DIMENSION {COMMON_DIMENSION}"
)

# No transform: only paths and labels are needed here, and `extractor.extract` opens
# and preprocesses each image itself (a dataset transform would be wasted work).
dataset = AFHQDataset(DATASET_PATH, limit=3200)

features_data = []
print("Extracting features...")
for image_path, label in tqdm(zip(dataset.images, dataset.labels), total=len(dataset)):
    filename = os.path.basename(image_path)
    try:
        features = extractor.extract(image_path)

        # Upload to Cloudinary under afhq/<label>/; with unique_filename=False and
        # overwrite=True a re-run replaces the same asset instead of duplicating it.
        cloudinary_result = cloudinary.uploader.upload(
            image_path,
            folder=f"afhq/{label}",
            use_filename=True,
            unique_filename=False,
            overwrite=True,
        )

        features_data.append({
            'id': f"{label}_{filename}",
            'values': features.tolist(),
            'metadata': {
                'project': PROJECT_NAME,
                'category': label,
                'filename': filename,
                'path': image_path,
                'animal_type': label,
                # Index (not .get) so a missing URL fails this image here rather than
                # storing a null metadata value that Pinecone rejects at upsert time.
                'image_url': cloudinary_result['secure_url'],
            },
        })

    except Exception as e:
        # Best effort: log and skip the image so one bad file doesn't abort the whole run.
        print(f"Error processing {image_path}: {e}")

print(f"Extracted features for {len(features_data)} images")

In [None]:
# Setup Pinecone
pc = Pinecone(api_key=userdata.get('PINECONE_API_KEY'))

# Shared index; each project writes into its own namespace.
INDEX_NAME = "multi-project-index"
INDEX_READY_TIMEOUT_S = 300  # give up waiting for the index after this many seconds

# Create the index on first use, otherwise reuse the existing one.
if not pc.has_index(INDEX_NAME):
    print(f"Creating new index: {INDEX_NAME}")
    pc.create_index(
        name=INDEX_NAME,
        dimension=COMMON_DIMENSION,
        metric="cosine",
        spec=ServerlessSpec(cloud='aws', region='us-east-1'),
    )
else:
    print(f"Using existing index: {INDEX_NAME}")

# Wait for the index to be ready, without hanging the notebook forever.
deadline = time.monotonic() + INDEX_READY_TIMEOUT_S
index_description = pc.describe_index(INDEX_NAME)
while not index_description.status['ready']:
    if time.monotonic() > deadline:
        raise TimeoutError(f"Index {INDEX_NAME} not ready after {INDEX_READY_TIMEOUT_S}s")
    time.sleep(1)
    index_description = pc.describe_index(INDEX_NAME)

# An existing index may have been created by another project with a different width;
# fail here with a clear message instead of on the first upsert.
if index_description.dimension != COMMON_DIMENSION:
    raise ValueError(
        f"Index {INDEX_NAME} has dimension {index_description.dimension}, "
        f"expected {COMMON_DIMENSION}"
    )

index = pc.Index(INDEX_NAME)

# Check current stats
stats = index.describe_index_stats()
print(f"\nCurrent index stats: {stats}")
print(f"Total vectors: {stats.total_vector_count}")
if stats.namespaces:
    print(f"Namespaces: {list(stats.namespaces.keys())}")

In [None]:
# Upload to current namespace
namespace = PROJECT_NAME
batch_size = 100  # vectors per upsert request

print(f"\nUploading to namespace: {namespace}")

# Refresh stats here so re-running this cell alone doesn't use a stale count.
stats = index.describe_index_stats()

# Check how much space we have left.
# NOTE(review): 100k is the assumed free-plan cap — confirm against current Pinecone limits.
# The check is conservative: upserting an existing id overwrites it rather than adding a vector.
current_vectors = stats.total_vector_count
free_tier_limit = 100000
space_left = free_tier_limit - current_vectors
print(f"Space used: {current_vectors}/{free_tier_limit} ({current_vectors/free_tier_limit*100:.1f}%)")
print(f"Space for new vectors: {space_left}")

# Keep features_data intact (so re-running is idempotent); truncate a copy instead.
vectors_to_upload = features_data
if len(features_data) > space_left:
    # Clamp at 0: a negative bound (features_data[:-k]) would drop items from the END
    # and still upload most of them instead of uploading nothing.
    space_left = max(space_left, 0)
    print(f"WARNING: Not enough space! Reducing to {space_left} vectors")
    vectors_to_upload = features_data[:space_left]

# Upload in batches
for start in tqdm(range(0, len(vectors_to_upload), batch_size)):
    batch = vectors_to_upload[start:start + batch_size]

    # Send only the fields Pinecone expects for each record.
    vectors = [
        {"id": item['id'], "values": item['values'], "metadata": item['metadata']}
        for item in batch
    ]

    # Upsert to namespace
    index.upsert(vectors=vectors, namespace=namespace)

# NOTE: index stats may lag just-upserted vectors for a short time.
stats = index.describe_index_stats()
print(f"\nUpdated index stats:")
print(f"Total vectors: {stats.total_vector_count}")
if namespace in stats.namespaces:
    print(f"Vectors in {namespace}: {stats.namespaces[namespace]['vector_count']}")