In [1]:
import os
import json
import logging
import requests
from urllib.parse import urlparse
import mimetypes

In [2]:
from PIL import Image
import numpy as np
import sqlite3
import pickle
import shutil
from scipy.spatial.distance import cosine
from transformers import ViTFeatureExtractor, ViTModel
import torch

In [3]:
# Load the preprocessor and the encoder from the same checkpoint so that image
# preprocessing follows the checkpoint's own preprocessor config.
# NOTE(review): recent transformers releases deprecate ViTFeatureExtractor in favour of
# ViTImageProcessor; consider migrating once the installed version allows it.
feature_extractor, vit_model = (
    ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k"),
    ViTModel.from_pretrained("google/vit-base-patch16-224-in21k"),
)

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [4]:
# Route every INFO-and-above record both to a persistent log file and to the notebook output.
# NOTE(review): basicConfig does nothing if the root logger already has handlers (e.g. when
# this cell is re-run); pass force=True (Python 3.8+) if reconfiguration is intended.
logging.basicConfig(
    handlers=[
        logging.FileHandler("image_fetch.log"),  # opened in append mode, so runs accumulate
        logging.StreamHandler(),  # writes to stderr by default
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

In [5]:
# Google Custom Search credentials are read from the environment only.
# NOTE(review): the key previously hard-coded here was committed in plain text and should be
# revoked/rotated in the Google Cloud console.
# The dedicated GI_API_KEY variable takes precedence. The generic API_KEY is still honoured for
# backward compatibility, but the Pexels cell reads it too, so set GI_API_KEY when using both sources.
GI_API_KEY = os.getenv('GI_API_KEY') or os.getenv('API_KEY')
GI_SEARCH_ENGINE_ID = os.getenv('SEARCH_ENGINE_ID', 'b6f2650fd0921483a')
if not GI_API_KEY:
    logging.warning("GI_API_KEY is not set; Google Image requests will fail until it is provided.")

In [6]:
# Pexels credentials are read from the environment only.
# NOTE(review): the key previously hard-coded here was committed in plain text and should be
# regenerated in the Pexels dashboard.
# PEXELS_API_KEY takes precedence. The generic API_KEY is honoured for backward compatibility,
# but the Google cell reads it too, so set PEXELS_API_KEY when using both sources.
PEXELS_API_KEY = os.getenv('PEXELS_API_KEY') or os.getenv('API_KEY')
if not PEXELS_API_KEY:
    logging.warning("PEXELS_API_KEY is not set; Pexels requests will fail until it is provided.")

In [7]:
def init_database(db_path="image_embeddings.db"):
    """
    Open the SQLite embedding store, creating the ``embeddings`` table on first use.

    The table holds one row per source URL (``url`` is UNIQUE) together with the
    serialized embedding blob.

    Args:
        db_path (str): Location of the SQLite file. Defaults to "image_embeddings.db".

    Returns:
        sqlite3.Connection: An open connection; closing it is the caller's job.
    """
    schema_sql = """
        CREATE TABLE IF NOT EXISTS embeddings (
            id INTEGER PRIMARY KEY,
            url TEXT UNIQUE,
            embedding BLOB
        )
    """
    connection = sqlite3.connect(db_path)
    connection.execute(schema_sql)
    connection.commit()
    return connection

def copy_existing_data(input_folder, output_folder, db_file):
    """
    Copy a previous dataset (images and embedding database) into a writable folder.

    Args:
        input_folder (str): Folder holding the previous dataset. If it does not exist
            (e.g. on the very first run) nothing is copied.
        output_folder (str): Destination folder. It is created if missing, and files
            already present at the same relative path are overwritten.
        db_file (str): File name of the SQLite database, relative to both folders.

    Returns:
        str: ``os.path.join(output_folder, db_file)``, the database path inside the
        writable folder. The file exists there only if it was present in ``input_folder``;
        otherwise ``sqlite3.connect`` will create it.
    """
    # Make sure the destination exists even when there is nothing to copy, so that
    # opening the returned database path cannot fail on a missing directory.
    os.makedirs(output_folder, exist_ok=True)
    if os.path.exists(input_folder):
        # copytree copies every file, db_file included, so no separate DB copy is needed.
        shutil.copytree(input_folder, output_folder, dirs_exist_ok=True)
    return os.path.join(output_folder, db_file)

def save_embedding_to_db(conn, url, embedding):
    """
    Persist one image embedding, keyed by the URL it was downloaded from.

    The embedding is pickled into a BLOB. Inserting a URL that is already stored
    violates the table's UNIQUE constraint; that case is logged and otherwise ignored.

    Args:
        conn (sqlite3.Connection): Open connection to the embedding database.
        url (str): Source URL of the image.
        embedding (numpy.ndarray): Embedding vector to store.
    """
    payload = pickle.dumps(embedding)
    insert_sql = "INSERT INTO embeddings (url, embedding) VALUES (?, ?)"
    try:
        conn.cursor().execute(insert_sql, (url, payload))
        conn.commit()
    except sqlite3.IntegrityError:
        logging.info(f"URL already exists in the database: {url}")

def is_similar_to_existing(conn, new_embedding, threshold=0.3):
    """
    Check whether an embedding is a near-duplicate of any embedding already stored.

    Similarity is cosine similarity, ``1 - scipy.spatial.distance.cosine(u, v)``.
    A stored embedding counts as a match when its similarity to ``new_embedding``
    is greater than ``1 - threshold``. In other words, ``threshold`` is the largest
    cosine *distance* at which two images are still treated as duplicates.

    Args:
        conn (sqlite3.Connection): Connection to the embedding database.
        new_embedding (numpy.ndarray): Embedding to compare against the stored ones.
        threshold (float, optional): Maximum cosine distance that still counts as a
            duplicate. Defaults to 0.3.

    Returns:
        bool: True as soon as one stored embedding is within ``threshold``. False
        otherwise, including when the table is empty.

    Note:
        Stored blobs are deserialized with ``pickle.loads``. Only use this on a
        database you trust, since unpickling untrusted data can execute code.
    """
    min_similarity = 1 - threshold
    cursor = conn.cursor()
    try:
        cursor.execute("SELECT embedding FROM embeddings")
        # Iterate the cursor instead of fetchall(), so memory does not grow with the
        # table and scanning stops at the first match.
        for (blob,) in cursor:
            existing_embedding = pickle.loads(blob)
            similarity = 1 - cosine(new_embedding, existing_embedding)
            if similarity > min_similarity:
                return True
        return False
    finally:
        # Release the (possibly unfinished) SELECT before the caller writes/commits.
        cursor.close()

def get_image_embedding(image_path):
    """
    Embed an image with the module-level ViT model, using its [CLS] token.

    Uses the global ``feature_extractor`` for preprocessing and ``vit_model`` for
    inference.

    Args:
        image_path (str): Path of the image file to embed.

    Returns:
        numpy.ndarray: 1-D embedding taken from the [CLS] position of the model's
        last hidden state.
    """
    # Open through a context manager so the file handle is released even if decoding
    # fails. Callers in this notebook delete or rename the file right afterwards.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")  # convert() returns a new, fully loaded image

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = vit_model(**inputs)

    # Index 0 along the token axis is the [CLS] token; the batch holds a single image.
    cls_embedding = outputs.last_hidden_state[:, 0, :]
    return cls_embedding.squeeze().numpy()

def load_page_state(state_file="page_state.json"):
    """
    Read the per-category crawl positions saved by a previous run.

    Args:
        state_file (str): JSON file to read. Defaults to "page_state.json".

    Returns:
        dict: Mapping of category to the next position to request (a result index or
        a page number, depending on the source). Empty if the file does not exist yet.
    """
    if not os.path.exists(state_file):
        return {}
    with open(state_file, "r") as handle:
        return json.load(handle)

def save_page_state(page_state, state_file="page_state.json"):
    """
    Write the per-category crawl positions so the next run can resume from them.

    Args:
        page_state (dict): Mapping of category to the next position to request.
        state_file (str): JSON file to (over)write. Defaults to "page_state.json".
    """
    with open(state_file, mode="w") as handle:
        json.dump(obj=page_state, fp=handle)

def fetch_images_from_google_image(category, num_results, state_file="page_state.json", api_key=GI_API_KEY, search_engine_id=GI_SEARCH_ENGINE_ID, timeout=30):
    """
    Fetch image URLs for a category via the Google Custom Search JSON API, resuming from the saved position.

    The per-category ``start`` index is read from ``state_file``. After fetching, the
    index of the next unseen result is written back, so the next call continues where
    this one stopped.

    The Custom Search JSON API never returns more than 100 results per query and
    rejects requests whose ``start + num`` exceeds 100. Requests are clamped to that
    window. Once a category reaches the end of it, further calls return no new URLs
    instead of failing with an HTTP error.

    Args:
        category (str): The search category or query term.
        num_results (int): Maximum number of image URLs to fetch.
        state_file (str, optional): Path to the JSON state file.
        api_key (str, optional): The API key for Google Custom Search.
        search_engine_id (str, optional): The search engine ID (``cx``) for Google Custom Search.
        timeout (float, optional): Seconds to wait for each API response. Defaults to 30.

    Returns:
        list[str]: Up to ``num_results`` image URLs (fewer if the API runs out).

    Raises:
        requests.HTTPError: If the API responds with an error status.
        requests.Timeout: If a response does not arrive within ``timeout`` seconds.
    """
    max_start_plus_num = 100  # documented API limit on start + num
    max_page_size = 10  # documented API maximum for `num`

    page_state = load_page_state(state_file)
    start = page_state.get(category, 1)

    url = "https://www.googleapis.com/customsearch/v1"
    image_urls = []

    while len(image_urls) < num_results:
        batch_size = min(max_page_size, num_results - len(image_urls), max_start_plus_num - start)
        if batch_size <= 0:
            logging.info(f"Reached the Custom Search 100-result limit for category '{category}'.")
            break
        params = {
            'key': api_key,
            'cx': search_engine_id,
            'q': category,
            'searchType': 'image',
            'num': batch_size,
            'start': start,
        }
        response = requests.get(url, params=params, timeout=timeout)
        response.raise_for_status()
        items = response.json().get('items', [])
        if not items:
            break
        image_urls.extend(item['link'] for item in items)
        start += len(items)

    # Persist the next start index so the following run resumes after these results.
    page_state[category] = start
    save_page_state(page_state, state_file)

    return image_urls[:num_results]

def fetch_images_from_pexels(category, num_results, state_file="pexels_page_state.json", api_key=PEXELS_API_KEY, per_page=80, timeout=30):
    """
    Fetch image URLs for a category from the Pexels search API, resuming from the saved page.

    Pexels numbers pages relative to ``per_page``, so every request uses the same page
    size. Shrinking it for the last, partial batch would make the saved page number
    point at a different result window on the next run, which re-fetches some photos
    and skips others. If the final page returns more photos than are still needed,
    the surplus is dropped and the saved state moves past that page.

    Args:
        category (str): The search category or query term.
        num_results (int): Maximum number of image URLs to fetch.
        state_file (str, optional): Path to the JSON state file. Defaults to "pexels_page_state.json".
        api_key (str, optional): The API key for the Pexels API.
        per_page (int, optional): Photos requested per page, clamped to the API's 1-80 range.
            Defaults to 80.
        timeout (float, optional): Seconds to wait for each API response. Defaults to 30.

    Returns:
        list[str]: Up to ``num_results`` URLs of the original-size images.

    Raises:
        requests.HTTPError: If the API responds with an error status.
        requests.Timeout: If a response does not arrive within ``timeout`` seconds.
    """
    headers = {"Authorization": api_key}
    base_url = "https://api.pexels.com/v1/search"
    page_size = max(1, min(int(per_page), 80))

    page_state = load_page_state(state_file)
    page = page_state.get(category, 1)
    image_urls = []

    while len(image_urls) < num_results:
        params = {
            "query": category,
            "per_page": page_size,
            "page": page,
        }
        response = requests.get(base_url, headers=headers, params=params, timeout=timeout)
        response.raise_for_status()
        data = response.json()

        photos = data.get("photos", [])
        if not photos:
            break

        image_urls.extend(photo["src"]["original"] for photo in photos)
        page += 1  # the next request (or next run) starts on the following page

        # `next_page` is absent on the last page of results; stop instead of requesting an empty one.
        if not data.get("next_page"):
            break

    page_state[category] = page
    save_page_state(page_state, state_file)

    return image_urls[:num_results]

def _choose_image_path(folder, url, content_type):
    """Build a destination path in ``folder`` for ``url`` that never overwrites an existing file."""
    file_name = os.path.basename(urlparse(url).path) or "image"
    stem, ext = os.path.splitext(file_name)
    if not ext:
        mime = content_type.split(';')[0].strip() if content_type else ''
        # guess_extension returns None for unknown MIME types; fall back to .jpg.
        ext = (mimetypes.guess_extension(mime) if mime else None) or '.jpg'
    candidate = os.path.join(folder, stem + ext)
    suffix = 1
    # Different URLs often share a basename (e.g. "image.jpg"); never clobber a kept image.
    while os.path.exists(candidate):
        candidate = os.path.join(folder, f"{stem}_{suffix}{ext}")
        suffix += 1
    return candidate


def download_images_with_deduplication(urls, folder, db_path="image_embeddings.db", timeout=30):
    """
    Download images, skip near-duplicates using embeddings, and save the unique ones.

    For each URL the image is streamed to a scratch file in ``folder`` and embedded,
    then compared against every embedding already in the database. If it is not
    similar to any of them, it is moved to a non-colliding file name in ``folder`` and
    its embedding is recorded. Otherwise it is discarded. Failures for individual URLs
    (network, decoding, embedding) are logged and the loop moves on. The scratch file
    is always removed and the database connection is always closed.

    Args:
        urls (list[str]): Image URLs to download.
        folder (str): Folder where unique images are saved (created if missing).
        db_path (str, optional): SQLite database holding the embeddings. Defaults to "image_embeddings.db".
        timeout (float, optional): Seconds to wait for each HTTP response. Defaults to 30.
    """
    os.makedirs(folder, exist_ok=True)
    conn = init_database(db_path)
    temp_image_path = os.path.join(folder, "temp.jpg")

    try:
        for url in urls:
            try:
                with requests.get(url, stream=True, allow_redirects=True, timeout=timeout) as response:
                    response.raise_for_status()
                    content_type = response.headers.get('Content-Type', '')
                    # Save image temporarily for embedding, streaming rather than buffering it all.
                    with open(temp_image_path, 'wb') as fh:
                        for chunk in response.iter_content(chunk_size=64 * 1024):
                            fh.write(chunk)

                embedding = get_image_embedding(temp_image_path)

                if is_similar_to_existing(conn, embedding):
                    logging.info(f"Duplicate image skipped: {url}")
                else:
                    image_path = _choose_image_path(folder, url, content_type)
                    os.replace(temp_image_path, image_path)
                    save_embedding_to_db(conn, url, embedding)
                    logging.info(f"Downloaded: {image_path}")
            except Exception as e:
                # Best-effort: one bad URL must not abort the whole batch.
                logging.error(f"Failed to download {url}: {e}")
            finally:
                # Never leave the scratch file behind in the dataset folder.
                try:
                    os.remove(temp_image_path)
                except FileNotFoundError:
                    pass  # already moved to its final name
                except OSError as cleanup_error:
                    logging.warning(f"Could not remove {temp_image_path}: {cleanup_error}")
    finally:
        conn.close()

def fetch_and_save_images(categories, num_results_per_category, 
                          base_folder="images", db_path="image_embeddings.db", state_file="page_state.json", type=None):
    """
    Fetch image URLs per category, download them with deduplication, and save them
    into ``<base_folder>/<category>/`` while tracking crawl progress.

    Args:
        categories (list[str]): Categories (search queries) to fetch images for.
        num_results_per_category (int): Number of image URLs to request per category.
        base_folder (str, optional): Root folder for the saved images. Defaults to "images".
        db_path (str, optional): SQLite database holding image embeddings. Defaults to "image_embeddings.db".
        state_file (str, optional): Crawl-state file passed to the source fetcher. Defaults to "page_state.json".
        type (str, optional): Image source, "Google Image" or "Pexels". Matching ignores
            case and surrounding whitespace. Any other value (including None) logs an
            error and returns before anything is created on disk.
    """
    # `type` shadows the builtin, but it is part of the public signature, so it is kept.
    source = type.strip().lower() if isinstance(type, str) else None
    if source not in ("google image", "pexels"):
        logging.error(f"Unrecognized image source type: {type}. Please specify 'Google Image' or 'Pexels'.")
        return

    os.makedirs(base_folder, exist_ok=True)

    for category in categories:
        category_folder = os.path.join(base_folder, category)
        os.makedirs(category_folder, exist_ok=True)

        if source == "google image":
            fetched_urls = fetch_images_from_google_image(category, num_results_per_category, state_file=state_file)
            source_label = "Google Image"
        else:
            fetched_urls = fetch_images_from_pexels(category, num_results_per_category, state_file=state_file)
            source_label = "Pexels"
        logging.info(f"Fetched {len(fetched_urls)} URLs from {source_label} for category '{category}'.")

        # Download and save images, skipping near-duplicates of anything already stored.
        download_images_with_deduplication(fetched_urls, category_folder, db_path)

In [8]:
# Search queries, grouped by theme. Each entry also becomes a sub-folder name under the
# images folder. The names appear to mirror COCO class labels (note COCO's spelling
# "hair drier"), so they are kept verbatim. TODO: confirm against the intended label set.
categories = (
    # Kitchen
    ["bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl"]
    # Indoor
    + ["book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]
    # Vehicle
    + ["bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat"]
    # Animal
    + ["bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe"]
)

In [None]:
def daily_run(input_folder="/kaggle/input/vqa-image",
              output_folder="/kaggle/working/vqa-image-updated",
              num_results_per_category=500,
              source="Pexels"):
    """
    Extend an existing image dataset with newly fetched images and re-package it as a zip.

    Steps:
    1. Copy the previous dataset (images and embedding DB) from ``input_folder`` into
       the writable ``output_folder``.
    2. Fetch new images for every entry of the module-level ``categories`` list into
       ``<output_folder>/images``, skipping near-duplicates.
    3. Zip ``output_folder`` into ``vqa-image-updated.zip`` in the current working directory.

    Args:
        input_folder (str, optional): Folder holding the previous dataset. Defaults to the
            Kaggle input path this notebook was written for.
        output_folder (str, optional): Writable working folder. Defaults to the Kaggle
            working path this notebook was written for.
        num_results_per_category (int, optional): URLs to request per category. Defaults to 500.
        source (str, optional): "Pexels" or "Google Image". Selects the API and its
            crawl-state file. Defaults to "Pexels".
    """
    db_file = "image_embeddings.db"
    output_images_folder = os.path.join(output_folder, "images")
    # Each source keeps its own crawl state, because Google tracks result offsets and Pexels tracks pages.
    is_google = str(source).strip().lower() == "google image"
    state_file = "page_state_gi.json" if is_google else "page_state_pexels.json"

    db_path = copy_existing_data(input_folder, output_folder, db_file)

    fetch_and_save_images(categories=categories,
                          num_results_per_category=num_results_per_category,
                          base_folder=output_images_folder,
                          db_path=db_path,
                          state_file=state_file,
                          type=source)

    zip_path = shutil.make_archive("vqa-image-updated", 'zip', output_folder)
    logging.info(f"Updated dataset zipped at {zip_path}")

In [9]:
def first_run(num_results_per_category=500,
              base_folder="images",
              db_path="image_embeddings.db",
              source="Pexels"):
    """
    Populate ``base_folder`` with images for every entry of the module-level ``categories`` list.

    Paths are relative to the current working directory. Crawl state is resumed from
    the source's state file if it already exists.

    Args:
        num_results_per_category (int, optional): URLs to request per category. Defaults to 500.
        base_folder (str, optional): Root folder for the images. Defaults to "images".
        db_path (str, optional): SQLite embedding database. Defaults to "image_embeddings.db".
        source (str, optional): "Pexels" or "Google Image". Selects the API and its
            crawl-state file. Defaults to "Pexels".
    """
    # Each source keeps its own crawl state, because Google tracks result offsets and Pexels tracks pages.
    is_google = str(source).strip().lower() == "google image"
    state_file = "page_state_gi.json" if is_google else "page_state_pexels.json"

    fetch_and_save_images(categories=categories,
                          num_results_per_category=num_results_per_category,
                          base_folder=base_folder,
                          db_path=db_path,
                          state_file=state_file,
                          type=source)

In [None]:
# first_run()

In [None]:
# daily_run()