In [1]:
import os
import numpy as np
import json
import logging
import requests
from urllib.parse import urlparse
import mimetypes

In [2]:
from PIL import Image
import shutil
import socket
import sqlite3
import pickle
from scipy.spatial.distance import cosine
from transformers import ViTFeatureExtractor, ViTModel
import torch
import time

In [3]:
# Pretrained ViT checkpoint used to embed images for near-duplicate detection.
VIT_CHECKPOINT = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(VIT_CHECKPOINT)
vit_model = ViTModel.from_pretrained(VIT_CHECKPOINT)



In [4]:
# Every log record goes both to a persistent log file and to the notebook output.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
_log_handlers = [logging.FileHandler("image_fetch.log"), logging.StreamHandler()]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_log_handlers)

#### API setup

In [5]:
# Openverse API configuration.
# Logging is configured once in the cell above. A second logging.basicConfig()
# call here would be a silent no-op (the root logger already has handlers),
# so it is intentionally not repeated.
API_CREDENTIALS_FILE = "openverse_credentials.json"
API_REGISTER_URL = "https://api.openverse.engineering/v1/auth/register"

In [6]:
def load_openverse_credentials(credentials_file=None):
    """Load Openverse API credentials from disk and export them as environment variables.

    A credentials file that cannot be parsed, is not a JSON object, or lacks a
    non-empty ``client_id``/``client_secret`` pair is deleted so that the caller
    can re-register.

    Args:
        credentials_file (str, optional): Path of the JSON credentials file.
            Defaults to the module-level ``API_CREDENTIALS_FILE``.

    Returns:
        dict | None: The parsed credentials when valid, otherwise ``None``.
    """
    if credentials_file is None:
        credentials_file = API_CREDENTIALS_FILE

    if not os.path.exists(credentials_file):
        logging.info("No existing API credentials found.")
        return None

    try:
        with open(credentials_file, "r") as f:
            credentials = json.load(f)
    except FileNotFoundError:
        # The file vanished between the exists() check and open(); nothing to delete.
        logging.info("No existing API credentials found.")
        return None
    except ValueError as e:
        # Covers json.JSONDecodeError as well as UnicodeDecodeError from binary garbage.
        logging.error(f"Error reading credentials file: {e}")
        _discard_credentials_file(credentials_file)
        return None

    # Valid JSON that is not an object (e.g. a list) has no .get(); treat it as invalid.
    if isinstance(credentials, dict):
        client_id = credentials.get("client_id")
        client_secret = credentials.get("client_secret")
    else:
        client_id = client_secret = None

    if client_id and client_secret:
        # os.environ only accepts strings.
        os.environ["OPENVERSE_CLIENT_ID"] = str(client_id)
        os.environ["OPENVERSE_CLIENT_SECRET"] = str(client_secret)
        logging.info("Loaded API credentials successfully.")
        return credentials

    logging.error("Invalid credentials file. Deleting and re-registering.")
    _discard_credentials_file(credentials_file)
    return None  # No valid credentials


def _discard_credentials_file(path):
    """Delete a corrupt credentials file, tolerating it already being gone."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def register_api(timeout=30):
    """Register a new Openverse application and persist the returned credentials.

    Posts the application details to ``API_REGISTER_URL``, writes the JSON
    response to ``API_CREDENTIALS_FILE`` and exports ``client_id`` /
    ``client_secret`` as environment variables.

    Args:
        timeout (float, optional): Seconds to wait for the registration request
            before giving up. Defaults to 30, matching the image fetchers.

    Returns:
        dict | None: The credentials returned by the API, or ``None`` if the
        request failed or the response was not valid JSON. If only saving the
        file fails, the credentials are still returned and exported for this
        session.
    """
    data = {
        "name": "VQA Image Project Fetcher",
        "description": "A project that fetches diverse images from Openverse to build a dataset for Visual Question Answering (VQA) research.",
        "email": "zzzz10091092zzzz@gmail.com"
    }

    try:
        # Without a timeout a stalled request would block the notebook indefinitely.
        response = requests.post(API_REGISTER_URL, json=data, timeout=timeout)
        response.raise_for_status()  # Raise error for bad HTTP responses
        credentials = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on older requests versions.
        logging.error(f"API Registration Failed: {e}")
        print(f"API Registration Failed: {e}")
        return None

    saved = True
    try:
        with open(API_CREDENTIALS_FILE, "w") as f:
            json.dump(credentials, f, indent=4)
    except OSError as e:
        # The credentials are still usable now; they just won't be reused next run.
        saved = False
        logging.error(f"Could not save API credentials to {API_CREDENTIALS_FILE}: {e}")

    os.environ["OPENVERSE_CLIENT_ID"] = credentials.get("client_id", "")
    os.environ["OPENVERSE_CLIENT_SECRET"] = credentials.get("client_secret", "")

    if saved:
        logging.info("Successfully registered API and saved credentials.")
    else:
        logging.info("Registered API, but credentials were not saved to disk.")
    return credentials

In [7]:
# Prefer credentials cached on disk; fall back to registering a new application.
credentials = load_openverse_credentials()
if not credentials:
    logging.info("Attempting API registration...")
    credentials = register_api()

status_message = "API credentials are ready to use." if credentials else "Failed to set up API credentials."
print(status_message)

2025-02-08 14:10:30,886 - INFO - Loaded API credentials successfully.


API credentials are ready to use.


In [8]:
# Values exported by load_openverse_credentials()/register_api(); None when unset.
OPENVERSE_CLIENT_ID, OPENVERSE_CLIENT_SECRET = (
    os.environ.get(var_name) for var_name in ("OPENVERSE_CLIENT_ID", "OPENVERSE_CLIENT_SECRET")
)

In [9]:
def _load_api_credentials(filename):
    """Return the ``"credentials"`` list stored in a JSON key file.

    Raises:
        FileNotFoundError: If ``filename`` does not exist.
        KeyError: If the file has no top-level ``"credentials"`` entry.
    """
    # JSON is UTF-8 by specification; don't depend on the platform default encoding.
    with open(filename, "r", encoding="utf-8") as file:
        return json.load(file)["credentials"]


def load_google_api_keys(filename="google_api_keys.json"):
    """Load Google Custom Search credentials (entries with ``api_key`` and ``search_engine_id``)."""
    return _load_api_credentials(filename)


def load_pixabay_api_keys(filename="pixabay_api_keys.json"):
    """Load Pixabay credentials (entries with an ``api_key``)."""
    return _load_api_credentials(filename)

In [10]:
GI_API_CREDENTIALS = load_google_api_keys(filename="google_api_keys.json")  # rotated on HTTP 403/429

In [11]:
PIXABAY_API_CREDENTIALS = load_pixabay_api_keys(filename="pixabay_api_keys.json")  # rotated on HTTP 403/429

In [12]:
# Pexels API key, read from the environment rather than hard-coded in the notebook
# (keys committed to a notebook are visible to anyone with the file; any keys that
# were previously embedded in this cell should be revoked and regenerated).
# PEXELS_API_KEY is preferred; the generic API_KEY variable is still honoured for
# backward compatibility with existing setups.
PEXELS_API_KEY = os.getenv("PEXELS_API_KEY") or os.getenv("API_KEY")
if not PEXELS_API_KEY:
    logging.warning("PEXELS_API_KEY is not set; Pexels requests will likely be rejected.")

#### Fetch images

In [13]:
def init_database(db_path="image_embeddings.db"):
    """
    Open (or create) the SQLite embeddings store and make sure its table exists.

    Args:
        db_path (str): Location of the SQLite file. Defaults to "image_embeddings.db".

    Returns:
        sqlite3.Connection: An open connection; the caller closes it when done.
    """
    schema = """
        CREATE TABLE IF NOT EXISTS embeddings (
            id INTEGER PRIMARY KEY,
            url TEXT UNIQUE,
            embedding BLOB
        )
    """
    connection = sqlite3.connect(db_path)
    connection.execute(schema)
    connection.commit()
    return connection

def copy_existing_data(input_folder, output_folder, db_file):
    """Copy previously collected data (images and database) into a working folder.

    The whole ``input_folder`` tree is merged into ``output_folder``. The
    ``output_folder`` is always created, so the returned database path points
    into an existing directory even when there is no prior data.

    Args:
        input_folder (str): The folder containing the existing data.
        output_folder (str): The folder to copy the data to.
        db_file (str): The filename of the database, relative to both folders.

    Returns:
        str: The path of the database inside ``output_folder``. The file itself
        exists only if it was present in ``input_folder`` or already in
        ``output_folder``.
    """
    os.makedirs(output_folder, exist_ok=True)
    if os.path.exists(input_folder):
        shutil.copytree(input_folder, output_folder, dirs_exist_ok=True)

    db_input = os.path.join(input_folder, db_file)
    db_output = os.path.join(output_folder, db_file)
    # copytree normally brings the database along already; only copy it
    # explicitly if it did not land there (e.g. db_file points outside the tree).
    if os.path.exists(db_input) and not os.path.exists(db_output):
        shutil.copy(db_input, db_output)
    return db_output
    

def save_embedding_to_db(conn, url, embedding):
    """Save an image embedding to the database, keyed by its source URL.

    A URL that is already stored (the ``url`` column is ``UNIQUE``) is logged
    and left unchanged.

    Args:
        conn (sqlite3.Connection): Connection to the database.
        url (str): URL of the image.
        embedding (numpy.ndarray): The image embedding to save.
    """
    # NOTE: embeddings are stored pickled; only open databases from trusted
    # sources, since unpickling them can execute arbitrary code.
    embedding_blob = pickle.dumps(embedding)
    # OR IGNORE skips duplicates without raising. A failed plain INSERT would
    # leave the implicit transaction open (holding the write lock) until the
    # next commit.
    cursor = conn.execute(
        "INSERT OR IGNORE INTO embeddings (url, embedding) VALUES (?, ?)",
        (url, embedding_blob),
    )
    conn.commit()
    if cursor.rowcount == 0:
        logging.info(f"URL already exists in the database: {url}")

def is_similar_to_existing(conn, new_embedding, threshold=0.5):
    """
    Check whether ``new_embedding`` is a near-duplicate of any stored embedding.

    Args:
        conn (sqlite3.Connection): Connection to the database.
        new_embedding (numpy.ndarray): The embedding to check.
        threshold (float, optional): Cosine-similarity cutoff. A stored
            embedding whose similarity to ``new_embedding`` is strictly greater
            than this counts as a duplicate. Defaults to 0.5.

    Returns:
        bool: True as soon as one stored embedding exceeds ``threshold``, False otherwise.
    """
    # Stream rows instead of fetchall() so the whole table is never held in
    # memory, and stop at the first match.
    cursor = conn.execute("SELECT embedding FROM embeddings")
    try:
        for (blob,) in cursor:
            # NOTE: unpickling is only safe for databases from trusted sources.
            existing_embedding = pickle.loads(blob)
            similarity = 1 - cosine(new_embedding, existing_embedding)
            if similarity > threshold:
                return True
        return False
    finally:
        cursor.close()

def get_image_embedding(image_path):
    """
    Generate an embedding for an image using the global Vision Transformer (ViT).

    The embedding is the first token (CLS position) of the model's last hidden
    state, L2-normalised.

    Args:
        image_path (str): Path to the image file for which the embedding is to be generated.

    Returns:
        numpy.ndarray: The normalised embedding (1-D after squeezing the
        single-image batch dimension).
    """
    # Open the file in a context manager so its handle is released before the
    # caller moves or deletes it. PIL can otherwise keep the source file open
    # (e.g. for multi-frame formats), which breaks shutil.move/os.remove on Windows.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = vit_model(**inputs)
    # Token 0 of the last hidden state for the (single) image in the batch.
    embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()

    norm = np.linalg.norm(embedding)
    if norm == 0:
        # Avoid a NaN vector that would silently break cosine-similarity checks.
        return embedding
    return embedding / norm

def load_page_state(state_file="page_state.json"):
    """
    Read the per-category crawl bookmarks.

    Args:
        state_file (str): JSON file holding the bookmarks. Defaults to "page_state.json".

    Returns:
        dict: Parsed contents of ``state_file``, or an empty dict when the file does not exist.
    """
    if not os.path.exists(state_file):
        return {}
    with open(state_file, "r") as handle:
        return json.load(handle)

def save_page_state(page_state, state_file="page_state.json"):
    """
    Overwrite the bookmark file with the given per-category crawl positions.

    Args:
        page_state (dict): Mapping of category to its last crawl position.
        state_file (str): Destination JSON file. Defaults to "page_state.json".
    """
    with open(state_file, mode="w") as handle:
        json.dump(page_state, handle)

def fetch_images_from_google_image(category, num_results, state_file="google_page_state.json",
                                   api_credentials=GI_API_CREDENTIALS, max_retries=5, retry_delay=5):
    """
    Fetch image URLs from Google Custom Search API with API key rotation.

    The 1-based result offset reached for each category is stored in
    ``state_file`` so later runs continue where this one stopped.

    Args:
        category (str): The search category or query term.
        num_results (int): The total number of image URLs to fetch.
        state_file (str, optional): Path to the state file.
        api_credentials (list, optional): List of API credentials with keys and search engine IDs.
        max_retries (int, optional): Maximum retry attempts for failed requests.
        retry_delay (int, optional): Delay in seconds between retry attempts.

    Returns:
        list[str]: Up to ``num_results`` image URLs. Fewer are returned when the
        search runs out of results, the API's 100-result window is used up, or
        a request keeps failing for ``max_retries`` attempts. The sentinel list
        ``["OUT_OF_QUERIES"]`` is returned when every API key was rejected with
        HTTP 403/429.
    """
    base_url = "https://www.googleapis.com/customsearch/v1"
    MAX_PER_PAGE = 10
    # The JSON API never serves more than 100 results; requests whose
    # start + num exceeds 100 are rejected with an error.
    MAX_RESULT_INDEX = 100

    page_state = load_page_state(state_file)
    start = page_state.get(category, 1)
    image_urls = []
    total_needed = num_results

    api_index = 0  # Track current API key index

    def persist_state():
        page_state[category] = start
        save_page_state(page_state, state_file)

    while len(image_urls) < total_needed:
        if api_index >= len(api_credentials):
            logging.error("All API keys exhausted. Out of queries for today.")
            # State is deliberately not saved: the URLs gathered in this call are
            # discarded, so the next run should fetch those results again.
            return ["OUT_OF_QUERIES"]

        remaining = total_needed - len(image_urls)
        current_batch_size = min(MAX_PER_PAGE, remaining, MAX_RESULT_INDEX - start)
        if current_batch_size <= 0:
            # Without this check every later request returns HTTP 400 forever.
            logging.info(f"Reached the Google Custom Search result limit after fetching {len(image_urls)} URLs")
            break

        credential = api_credentials[api_index]
        params = {
            'key': credential["api_key"],
            'cx': credential["search_engine_id"],
            'q': category,
            'searchType': 'image',
            'num': current_batch_size,
            'start': start,
        }

        for attempt in range(max_retries):
            try:
                # Check DNS resolution first
                try:
                    socket.gethostbyname('www.googleapis.com')
                except socket.gaierror:
                    logging.error("DNS resolution failed. Retrying...")
                    time.sleep(retry_delay)
                    continue

                response = requests.get(base_url, params=params, timeout=30)
                if response.status_code in (403, 429):
                    # Log the key's position, not the key itself, to keep secrets out of the log file.
                    logging.warning(f"API key #{api_index} has hit its limit. Switching keys...")
                    api_index += 1  # Switch to the next API key
                    break  # Exit retry loop to change API key

                response.raise_for_status()
                items = response.json().get('items', [])
                if not items:
                    logging.info(f"No more images available after fetching {len(image_urls)} URLs")
                    persist_state()
                    return image_urls

                new_urls = [item['link'] for item in items]
                image_urls.extend(new_urls)
                logging.info(f"Fetched {len(new_urls)} URLs starting at index {start}. Total: {len(image_urls)}/{total_needed}")

                start += len(items)
                break  # Success, exit retry loop

            except requests.exceptions.RequestException as e:
                # str(e) can embed the request URL, whose query string carries the API key.
                status = getattr(e.response, "status_code", None)
                logging.warning(f"Request attempt {attempt + 1} failed ({type(e).__name__}, status {status}). Retrying...")
                time.sleep(retry_delay)

            except Exception as e:
                logging.error(f"Unexpected error: {str(e)}")
                persist_state()
                return image_urls  # Return what we have so far
        else:
            # Every attempt failed: give up instead of letting the outer loop
            # retry the same request forever.
            logging.error(f"Giving up after {max_retries} failed attempts")
            persist_state()
            return image_urls

        # Small delay to avoid rate limiting
        time.sleep(5)

    persist_state()
    return image_urls[:total_needed]

def fetch_images_from_pexels(category, num_results, state_file="pexels_page_state.json", api_key=PEXELS_API_KEY, max_retries=3, retry_delay=5):
    """
    Fetch image URLs from the Pexels search API, paging until enough are collected.

    The next page number for each category is stored in ``state_file`` so later
    runs continue where this one stopped.

    Args:
        category (str): The search category or query term.
        num_results (int): The total number of image URLs to fetch.
        state_file (str, optional): Path to the state file.
        api_key (str, optional): The API key for Pexels API.
        max_retries (int, optional): Maximum number of retry attempts for failed requests.
        retry_delay (int, optional): Delay in seconds between retry attempts.

    Returns:
        list[str]: Up to ``num_results`` image URLs. Fewer are returned when
        results run out or requests keep failing.
    """
    headers = {"Authorization": api_key}
    base_url = "https://api.pexels.com/v1/search"
    MAX_PER_PAGE = 80

    page_state = load_page_state(state_file)
    page = page_state.get(category, 1)
    image_urls = []
    total_needed = num_results

    def persist_state():
        page_state[category] = page
        save_page_state(page_state, state_file)

    while len(image_urls) < total_needed:
        remaining = total_needed - len(image_urls)
        # NOTE(review): per_page shrinks for the final batch while `page` is
        # stored as a page number, so page boundaries shift whenever the batch
        # size changes and some results may be repeated or skipped. Repeats are
        # filtered by URL downstream; consider a fixed per_page if skips matter.
        current_batch_size = min(MAX_PER_PAGE, remaining)

        params = {
            "query": category,
            "per_page": current_batch_size,
            "page": page,
        }

        for attempt in range(max_retries):
            try:
                # Check DNS resolution first
                try:
                    socket.gethostbyname('api.pexels.com')
                except socket.gaierror:
                    logging.error("DNS resolution failed. Checking network connection...")
                    time.sleep(retry_delay)
                    continue

                response = requests.get(
                    base_url,
                    headers=headers,
                    params=params,
                    timeout=30
                )
                response.raise_for_status()
                photos = response.json().get("photos", [])
                if not photos:
                    logging.info(f"No more photos available after fetching {len(image_urls)} URLs")
                    persist_state()
                    return image_urls  # No more results available

                new_urls = [photo["src"]["original"] for photo in photos]
                image_urls.extend(new_urls)
                logging.info(f"Fetched {len(new_urls)} URLs on page {page}. Total: {len(image_urls)}/{total_needed}")

                page += 1
                break  # Success, exit retry loop

            except requests.exceptions.RequestException as e:
                if attempt < max_retries - 1:
                    logging.warning(f"Attempt {attempt + 1} failed: {str(e)}. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                else:
                    logging.error(f"Failed to fetch images after {max_retries} attempts: {str(e)}")
                    persist_state()
                    return image_urls  # Return what we have so far

            except Exception as e:
                logging.error(f"Unexpected error: {str(e)}")
                # Persist progress like the other fetchers so the next run resumes here.
                persist_state()
                return image_urls  # Return what we have so far
        else:
            # Reached only when every attempt took the DNS-failure `continue`;
            # without this the outer loop would retry forever.
            logging.error(f"Failed to fetch images after {max_retries} attempts: DNS resolution kept failing")
            persist_state()
            return image_urls

        # Add a small delay between pages to avoid rate limiting
        time.sleep(0.5)

    persist_state()
    return image_urls[:total_needed]

def fetch_images_from_openverse(category, num_results, state_file="page_state_openverse.json", max_retries=5, retry_delay=10):
    """
    Fetch image URLs from the Openverse API.

    The next page number for each category is stored in ``state_file`` so later
    runs continue where this one stopped.

    Args:
        category (str): The search category or query term.
        num_results (int): The total number of image URLs to fetch.
        state_file (str, optional): Path to the state file for pagination.
        max_retries (int, optional): Maximum number of retry attempts for failed requests.
        retry_delay (int, optional): Base delay in seconds for exponential backoff.

    Returns:
        list[str]: Up to ``num_results`` image URLs. Fewer are returned when
        results run out, requests keep failing, or authentication fails
        repeatedly.
    """
    base_url = "https://api.openverse.engineering/v1/images"
    MAX_PER_PAGE = 20

    page_state = load_page_state(state_file)
    page = page_state.get(category, 1)
    print(f"Starting from page {page}")
    image_urls = []
    total_needed = num_results

    # NOTE(review): Openverse documents an OAuth2 client-credentials flow
    # (exchange client_id/client_secret for an access token, then send
    # "Authorization: Bearer <token>"). Sending the raw client_id as "Token ..."
    # is not that flow. Confirm whether these requests are actually
    # authenticated or are being served anonymously.
    headers = {
        "Authorization": "Token " + str(OPENVERSE_CLIENT_ID).strip(),
        "User-Agent": "VQA Image Project Fetcher (zzzz10091092zzzz@gmail.com)"
    }

    def persist_state():
        page_state[category] = page
        save_page_state(page_state, state_file)

    auth_failures = 0  # Authentication failures across all pages of this call
    while len(image_urls) < total_needed:
        remaining = total_needed - len(image_urls)
        current_batch_size = min(MAX_PER_PAGE, remaining)

        params = {
            "q": category,
            "format": "json",
            "page": page,
            "page_size": current_batch_size,
        }

        attempt = 0
        while attempt < max_retries:
            try:
                response = requests.get(base_url, params=params, headers=headers, timeout=30)

                if response.status_code == 401:
                    auth_failures += 1
                    logging.error("Authentication failed! Check your API credentials.")
                    # Stop early if authentication keeps failing (no point sleeping first).
                    if auth_failures >= 3:
                        logging.error("Multiple authentication failures! Stopping requests to avoid being blocked.")
                        persist_state()
                        return image_urls  # Return what we have so far
                    time.sleep(30)  # Back off for 30 seconds before retrying
                    continue

                if response.status_code == 429:  # Too Many Requests (rate-limited)
                    wait_time = retry_delay * (2 ** attempt)  # Exponential backoff
                    logging.warning(f"Rate limit reached. Waiting {wait_time} seconds before retrying...")
                    time.sleep(wait_time)
                    attempt += 1
                    continue

                response.raise_for_status()
                results = response.json().get("results", [])
                if not results:
                    logging.info(f"No more images available after fetching {len(image_urls)} URLs")
                    persist_state()
                    return image_urls  # No more results available

                new_urls = [item["url"] for item in results if "url" in item]
                image_urls.extend(new_urls)
                logging.info(f"Fetched {len(new_urls)} URLs on page {page}. Total: {len(image_urls)}/{total_needed}")

                page += 1
                break  # Success, exit retry loop

            except requests.exceptions.RequestException as e:
                attempt += 1
                if attempt < max_retries:
                    wait_time = retry_delay * (2 ** attempt)  # Exponential backoff
                    logging.warning(f"Attempt {attempt} failed: {str(e)}. Retrying in {wait_time} seconds...")
                    time.sleep(wait_time)
                else:
                    logging.error(f"Failed to fetch images after {max_retries} attempts: {str(e)}")
                    persist_state()
                    return image_urls  # Return what we have so far
        else:
            # All retries were consumed by HTTP 429 responses. Stop instead of
            # letting the outer loop restart the retry cycle forever.
            logging.error(f"Rate limited on all {max_retries} attempts; stopping.")
            persist_state()
            return image_urls

        # Add a small delay between pages to avoid rate limiting
        time.sleep(5)

    persist_state()
    print(f"Total URLs fetched: {len(image_urls)}")

    return image_urls[:total_needed]

def fetch_images_from_pixabay(category, num_results, state_file="pixabay_page_state.json", 
                              api_credentials=PIXABAY_API_CREDENTIALS, max_retries=3, retry_delay=5):
    """
    Fetch image URLs from Pixabay API with automatic API key rotation.

    The next page number for each category is stored in ``state_file`` so later
    runs continue where this one stopped.

    Args:
        category (str): The search category or query term.
        num_results (int): The total number of image URLs to fetch.
        state_file (str, optional): Path to the state file for pagination.
        api_credentials (list, optional): List of API keys for rotation.
        max_retries (int, optional): Maximum number of retry attempts for failed requests.
        retry_delay (int, optional): Delay in seconds between retry attempts.

    Returns:
        list[str]: Up to ``num_results`` image URLs. Fewer are returned when
        results run out or requests keep failing. The sentinel list
        ``["OUT_OF_QUERIES"]`` is returned when every API key was rejected
        with HTTP 403/429.
    """
    base_url = "https://pixabay.com/api/"
    MAX_PER_PAGE = 100
    # Pixabay only accepts per_page values from 3 to 200. Smaller values are
    # rejected with HTTP 400, which previously led to an endless retry loop.
    MIN_PER_PAGE = 3

    page_state = load_page_state(state_file)
    page = page_state.get(category, 1)
    image_urls = []
    total_needed = num_results
    api_index = 0  # Start with the first API key

    def persist_state():
        page_state[category] = page
        save_page_state(page_state, state_file)

    while len(image_urls) < total_needed:
        if api_index >= len(api_credentials):
            logging.error("All Pixabay API keys are exhausted. No more queries available.")
            return ["OUT_OF_QUERIES"]

        remaining = total_needed - len(image_urls)
        # NOTE(review): per_page varies with `remaining` while `page` is stored
        # as a page number, so page boundaries can shift between requests.
        current_batch_size = max(MIN_PER_PAGE, min(MAX_PER_PAGE, remaining))

        params = {
            "key": api_credentials[api_index]["api_key"],
            "q": category,
            "page": page,
            "per_page": current_batch_size,
            "image_type": "photo"
        }

        for attempt in range(max_retries):
            try:
                response = requests.get(base_url, params=params, timeout=30)

                if response.status_code in [403, 429]:  # Rate limit exceeded or key blocked
                    # Log the key's position, not the key itself, to keep secrets out of the log file.
                    logging.warning(f"API key #{api_index} reached its limit. Switching keys...")
                    api_index += 1  # Switch to the next API key
                    break  # Exit retry loop to change API key

                response.raise_for_status()
                hits = response.json().get("hits", [])
                if not hits:
                    logging.info(f"No more images available after fetching {len(image_urls)} URLs")
                    persist_state()
                    return image_urls  # No more results available

                new_urls = [hit["webformatURL"] for hit in hits if "webformatURL" in hit]
                image_urls.extend(new_urls)
                logging.info(f"Fetched {len(new_urls)} URLs on page {page}. Total: {len(image_urls)}/{total_needed}")

                page += 1
                break  # Success, exit retry loop

            except requests.exceptions.RequestException as e:
                # str(e) can embed the request URL, whose query string carries the API key.
                status = getattr(e.response, "status_code", None)
                logging.warning(f"Request attempt {attempt + 1} failed ({type(e).__name__}, status {status}). Retrying...")
                time.sleep(retry_delay)

            except Exception as e:
                logging.error(f"Unexpected error: {str(e)}")
                persist_state()
                return image_urls  # Return what we have so far
        else:
            # Every attempt failed: give up instead of letting the outer loop
            # retry the same request forever.
            logging.error(f"Giving up after {max_retries} failed attempts")
            persist_state()
            return image_urls

        # Add delay to avoid rate limiting
        time.sleep(1)

    persist_state()
    return image_urls[:total_needed]

def _unique_file_name(folder, url, content_type):
    """Derive a file name for ``url`` that does not collide with files already in ``folder``."""
    file_name = os.path.basename(urlparse(url).path) or 'image'
    stem, ext = os.path.splitext(file_name)
    if not ext:
        mime = content_type.split(';')[0].strip()
        # guess_extension returns None for unknown types; fall back to .jpg.
        ext = (mimetypes.guess_extension(mime) if mime else None) or '.jpg'
    candidate = stem + ext
    counter = 1
    # Many sources reuse generic basenames; never overwrite an already saved image.
    while os.path.exists(os.path.join(folder, candidate)):
        candidate = f"{stem}_{counter}{ext}"
        counter += 1
    return candidate


def _save_if_unique(conn, url, temp_path, category_folder, similarity_threshold):
    """Download ``url`` and keep it in ``category_folder`` unless it duplicates a stored image.

    Returns the saved file name, or None for a near-duplicate. Network errors
    propagate as ``requests`` exceptions.
    """
    with requests.get(url, timeout=10) as response:
        response.raise_for_status()
        content = response.content
        content_type = response.headers.get('Content-Type', '')

    with open(temp_path, 'wb') as f:
        f.write(content)

    embedding = get_image_embedding(temp_path)
    if is_similar_to_existing(conn, embedding, threshold=similarity_threshold):
        return None

    file_name = _unique_file_name(category_folder, url, content_type)
    shutil.move(temp_path, os.path.join(category_folder, file_name))
    save_embedding_to_db(conn, url, embedding)
    return file_name


def process_image_batch(urls, output_folder, db_path, category, target_count, fetch_func, state_file,
                        similarity_threshold=0.5):
    """
    Download images, skip near-duplicates, and save unique ones under ``output_folder/category``.

    More URLs are requested from ``fetch_func`` whenever the queue runs dry,
    until ``target_count`` unique images are saved or the source is exhausted.

    Args:
        urls (list): Initial image URLs to process (not modified).
        output_folder (str): Base output folder path.
        db_path (str): Path to embeddings database.
        category (str): Category name for subfolder organization.
        target_count (int): Desired number of unique images.
        fetch_func (callable): Function to fetch more image URLs.
        state_file (str): Path to the state file for tracking pages.
        similarity_threshold (float, optional): Cosine similarity above which a
            new image counts as a duplicate of a stored one. Defaults to 0.5.

    Returns:
        int: Number of new unique images saved.
    """
    from collections import deque

    category_folder = os.path.join(output_folder, category)
    os.makedirs(category_folder, exist_ok=True)
    temp_dir = os.path.join(output_folder, 'temp')

    # Work on a queue copy: O(1) pops, and the caller's list is left untouched.
    pending = deque(urls)
    processed_count = 0
    processed_urls = set()

    conn = init_database(db_path)
    try:
        os.makedirs(temp_dir, exist_ok=True)

        while processed_count < target_count:
            if not pending:
                print(f"Fetching more URLs for {category}...")
                additional_urls = fetch_func(
                    category=category,
                    num_results=target_count - processed_count,
                    state_file=state_file
                )
                # Fetchers return this sentinel when all API keys are spent; it is not a URL.
                if "OUT_OF_QUERIES" in additional_urls:
                    print(f"API quota exhausted while fetching {category}. Stopping early.")
                    break
                pending.extend(url for url in additional_urls if url not in processed_urls)

                if not pending:
                    print(f"Exhausted available images for {category}. Stopping early.")
                    break

            url = pending.popleft()
            if url in processed_urls:
                continue
            processed_urls.add(url)

            # Unique temporary filename for each download
            temp_path = os.path.join(temp_dir, f'temp_image_{int(time.time() * 1000)}.jpg')
            try:
                file_name = _save_if_unique(conn, url, temp_path, category_folder, similarity_threshold)
                if file_name is None:
                    print(f"Skipped duplicate: {url}")
                else:
                    processed_count += 1
                    print(f"Saved new image: {file_name} ({processed_count}/{target_count})")
            except requests.exceptions.RequestException as e:
                print(f"Network error processing {url}: {e}")
            except Exception as e:
                print(f"Error processing {url}: {e}")
            finally:
                if os.path.exists(temp_path):
                    try:
                        os.remove(temp_path)
                    except OSError:
                        pass  # Best effort; the whole temp dir is removed below anyway.
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
        conn.close()

    return processed_count

In [29]:
# Categories processed by daily_run. To crawl others, add names from the
# currently disabled groups below:
#   Kitchen: "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl"
#   Indoor:  "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
#            "toothbrush" (earlier note: Pixabay has a few; Pexels ran out)
#   Vehicle: "bicycle", "car", "bus", "train", "motorcycle", "truck", "boat", "airplane"
#   Animal:  "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "zebra", "giraffe"
categories = list(("bear",))  # Animal

In [15]:
def daily_run(image_source, KAGGLE_INPUT, KAGGLE_OUTPUT, num_results_per_category,
              category_list=None):
    """Top up each category folder towards ``num_results_per_category`` images.

    For every category, the files already present in
    ``<KAGGLE_OUTPUT>/images/<category>`` are counted. URLs for the shortfall
    are fetched from ``image_source`` and handed to ``process_image_batch``,
    which handles deduplication and any additional fetching (it receives the
    fetch function and state file for that purpose).

    Args:
        image_source: One of 'Pexels', 'Google Image', 'Openverse', 'Pixabay'.
        KAGGLE_INPUT: Directory that may contain an ``image_embeddings.db`` from
            a previous run. It is copied into ``KAGGLE_OUTPUT`` when the output
            directory has no database yet.
        KAGGLE_OUTPUT: Directory that receives ``image_embeddings.db`` and the
            ``images/<category>`` folders. Missing directories are created.
        num_results_per_category: Target number of files per category folder.
        category_list: Categories to process. Defaults to the notebook-level
            ``categories`` list.

    Raises:
        ValueError: If ``image_source`` is not a supported source. This is
            checked before any file-system side effect.
    """
    # State file passed to each source's fetch function.
    state_files = {
        'Pexels': "page_state_pexels.json",
        'Google Image': "page_state_gi.json",
        'Openverse': "page_state_openverse.json",
        'Pixabay': "page_state_pixabay.json"
    }

    # Validate first so a typo in image_source does not copy/create databases.
    if image_source not in state_files:
        raise ValueError(f"Invalid image source. Choose from: {list(state_files.keys())}")

    fetch_functions = {
        'Pexels': fetch_images_from_pexels,
        'Google Image': fetch_images_from_google_image,
        'Openverse': fetch_images_from_openverse,
        'Pixabay': fetch_images_from_pixabay
    }
    fetch_func = fetch_functions[image_source]
    state_file = state_files[image_source]

    if category_list is None:
        category_list = categories

    # Paths
    db_path = os.path.join(KAGGLE_OUTPUT, "image_embeddings.db")
    image_output_folder = os.path.join(KAGGLE_OUTPUT, "images")
    # Without this, shutil.copy below fails on a fresh output directory.
    os.makedirs(image_output_folder, exist_ok=True)

    # Seed the output DB from a previous run, but never overwrite an existing
    # one (this also avoids copying a file onto itself when input == output).
    existing_db = os.path.join(KAGGLE_INPUT, "image_embeddings.db")
    if os.path.exists(existing_db) and not os.path.exists(db_path):
        shutil.copy(existing_db, db_path)
        print("Copied existing database")

    # Initialise the database, then release the connection:
    # process_image_batch is given db_path, not this connection.
    conn = init_database(db_path)
    conn.close()

    for category in category_list:
        category_folder = os.path.join(image_output_folder, category)
        # On the first run for a category the folder does not exist yet and
        # os.listdir would raise FileNotFoundError.
        os.makedirs(category_folder, exist_ok=True)
        existing_images = sum(
            1 for f in os.listdir(category_folder)
            if os.path.isfile(os.path.join(category_folder, f))
        )

        remaining_images = max(0, num_results_per_category - existing_images)
        if remaining_images == 0:
            print(f"{category} already has {existing_images} images. Skipping.")
            continue

        print(f"\nProcessing category: {category}")
        print(f"Existing images: {existing_images}, Remaining to fetch: {remaining_images}")
        print(f"Fetching {remaining_images} images for {category} using {fetch_func.__name__}")

        # Treat a None result from the fetcher as "no URLs" instead of
        # crashing on len(None).
        urls = fetch_func(
            category=category,
            num_results=remaining_images,
            state_file=state_file
        ) or []
        print(f"Fetched {len(urls)} URLs for {category}")

        # Process images with deduplication and additional fetching if needed
        processed_count = process_image_batch(
            urls=urls,
            output_folder=image_output_folder,
            db_path=db_path,
            category=category,
            target_count=remaining_images,
            fetch_func=fetch_func,
            state_file=state_file
        )

        print(f"Completed processing {category}: {processed_count}/{remaining_images} unique images saved")

In [16]:
def print_image_counts(image_output_folder):
    """Report how many files each configured category folder contains.

    Walks the notebook-level ``categories`` list. A category whose folder does
    not exist is left out of the per-category listing, but it still counts
    toward the denominator of the "Categories with Images" line. Every regular
    file is counted, whatever its extension.

    Args:
        image_output_folder (str): Directory whose sub-folders are named after
            the categories.

    Returns:
        None. All results are printed.
    """
    print("\nCurrent Image Counts:")

    # Guard clause: nothing to count without the base folder.
    if not os.path.exists(image_output_folder):
        print(f"Output folder {image_output_folder} does not exist.")
        return

    counts = []
    for name in categories:
        folder = os.path.join(image_output_folder, name)
        if not os.path.exists(folder):
            continue
        n_files = sum(
            1 for entry in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, entry))
        )
        print(f"{name}: {n_files} images")
        counts.append(n_files)

    non_empty = sum(1 for n in counts if n > 0)
    print(f"\nTotal Images: {sum(counts)}")
    print(f"Categories with Images: {non_empty}/{len(categories)}")

In [17]:
# Target number of files per category folder; daily_run only fetches the shortfall.
num_results_per_category: int = 600

#### Image sources

Valid values for `image_source` (passed to `daily_run`):

- `'Pexels'`
- `'Google Image'`
- `'Openverse'`
- `'Pixabay'`

In [30]:
# Dataset roots. The defaults are machine-specific absolute paths; set
# VQA_INPUT_DIR / VQA_OUTPUT_DIR in the environment to run elsewhere without
# editing this cell.
# KAGGLE_INPUT may hold an image_embeddings.db from a previous run;
# KAGGLE_OUTPUT receives the database and the images/<category> folders.
KAGGLE_INPUT = os.environ.get("VQA_INPUT_DIR", "D:/Project/VQA/vqa_dataset")
KAGGLE_OUTPUT = os.environ.get("VQA_OUTPUT_DIR", "D:/Project/VQA/vqa_dataset")
# One of 'Pexels', 'Google Image', 'Openverse', 'Pixabay'.
image_source = 'Google Image'

In [31]:
# Snapshot of per-category file counts before this run's fetch.
print_image_counts(image_output_folder=os.path.join(KAGGLE_OUTPUT, "images"))


Current Image Counts:
bear: 552 images

Total Images: 552
Categories with Images: 1/1


In [None]:
# Run the daily top-up. Errors are logged rather than re-raised so that
# Run All still reaches the count cell below.
try:
    daily_run(
        image_source=image_source,
        KAGGLE_INPUT=KAGGLE_INPUT,
        KAGGLE_OUTPUT=KAGGLE_OUTPUT,
        num_results_per_category=num_results_per_category,
    )
except Exception as e:
    # logging.exception logs at ERROR level and appends the traceback, which
    # a plain str(e) would discard.
    logging.exception("Error in daily_run: %s", e)

In [33]:
# Per-category file counts after the fetch, to compare with the earlier snapshot.
print_image_counts(image_output_folder=os.path.join(KAGGLE_OUTPUT, "images"))


Current Image Counts:
bear: 584 images

Total Images: 584
Categories with Images: 1/1
