# Image Embedding Extraction with CLIP

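This notebook demonstrates how to extract image embeddings with CLIP (Contrastive Language-Image Pre-training). It uses a ViT-L/14 checkpoint that was trained with OpenCLIP and published under the `laion` organization on the Hugging Face Hub, and it loads that checkpoint through the Hugging Face `transformers` library. We'll go through the process step by step.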

## Step 1: Import Required Libraries

First, we import the necessary Python libraries. The comment next to each import notes its role in the script:

```python
import torch  # For neural network operations
from transformers import CLIPProcessor, CLIPModel  # For using the CLIP model
import glob  # For finding files matching a pattern
from PIL import Image  # For opening and manipulating images
from tqdm import tqdm  # For displaying progress bars
import os  # For interacting with the operating system
```

## Step 2: Load Image Paths

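Next, we collect the paths of all `.jpg` files in the source folder and its subfolders. The match is case-sensitive, so files ending in `.JPG` or `.jpeg` are not included, and `glob` skips hidden files (names starting with `.`, such as macOS `._` sidecar files):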

```python
# Define the source folder
source_folder = "/Volumes/Illustrated/data/extractedimages/illustratedweeklynews"

# Initialize an empty list to hold image paths
image_paths = []

# Walk through the directory structure
for root, dirs, files in os.walk(source_folder):
    # Collect all .jpg files in the current directory
    image_paths.extend(glob.glob(os.path.join(root, "*.jpg")))

# Sort the collected image paths
image_paths = sorted(image_paths)

# Print the number of collected image paths
print(f"Number of images found: {len(image_paths)}")
```
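
If you also need `.JPG` or `.jpeg` files, a recursive search with `pathlib` and a case-insensitive suffix check is one option. The sketch below (the function name `find_jpegs` is hypothetical) keeps the original loop's behavior of skipping hidden files, including macOS `._` sidecar files. Because the resume logic in Step 5 assumes the image list is identical between runs, switching to this search should go together with a fresh checkpoint file.

```python
from pathlib import Path


def find_jpegs(source_folder):
    """Return sorted paths (as strings) of .jpg/.jpeg files under source_folder.

    Hypothetical alternative to the os.walk + glob loop: the suffix check is
    case-insensitive, and files whose names start with "." are skipped, as
    glob's "*.jpg" pattern does.
    """
    extensions = {".jpg", ".jpeg"}
    return sorted(
        str(p)
        for p in Path(source_folder).rglob("*")
        if p.is_file()
        and p.suffix.lower() in extensions
        and not p.name.startswith(".")
    )
```

Usage would be `image_paths = find_jpegs(source_folder)` in place of the `os.walk` loop.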

## Step 3: Load CLIP Model and Processor

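Now we'll load the CLIP model and its associated processor, and move the model to the fastest available device (an NVIDIA GPU via CUDA, then an Apple-silicon GPU via MPS, then the CPU):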

```python
model_name = "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
model = CLIPModel.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model.to(device)
processor = CLIPProcessor.from_pretrained(model_name)

print(f"Model loaded and moved to device: {device}")
```
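
For reference, `CLIPProcessor` resizes and center-crops each image to the model's input resolution, scales pixel values to [0, 1], and normalizes each RGB channel with a fixed mean and standard deviation. The sketch below shows only that last step for a single pixel. It assumes the standard OpenAI CLIP constants; the values this checkpoint actually uses can be read from `processor.image_processor.image_mean` and `processor.image_processor.image_std`.

```python
# Assumed: standard OpenAI CLIP normalization constants (per RGB channel).
# Check processor.image_processor.image_mean / image_std for the real values.
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def normalize_pixel(rgb):
    """Map one 8-bit RGB pixel to the normalized values the vision encoder sees."""
    return tuple(
        (value / 255.0 - mean) / std  # rescale to [0, 1], then standardize
        for value, mean, std in zip(rgb, CLIP_MEAN, CLIP_STD)
    )
```

A pure black pixel, for example, maps to `-mean / std` in each channel, so the encoder's inputs are centered near zero rather than spanning 0 to 255.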

## Step 4: Define Helper Functions

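We define two helpers for the checkpoint file. `load_embeddings` returns the saved tensor, or `None` if no checkpoint exists yet. `save_embeddings` *appends* the given embeddings to whatever is already on disk and writes the result back, so it should only be passed embeddings that have not been saved yet: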

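```python
def load_embeddings(file_path):
    """Return the embeddings tensor saved at file_path, or None if no checkpoint exists yet."""
    if os.path.exists(file_path):
        return torch.load(file_path)
    return None

def save_embeddings(embeddings, file_path):
    """Append `embeddings` to any embeddings already saved at file_path, then write the result back."""
    if os.path.exists(file_path):
        existing_embeddings = torch.load(file_path)
        embeddings = torch.cat([existing_embeddings, embeddings], dim=0)
    torch.save(embeddings, file_path)
```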
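
The append-only contract is easy to see with plain lists. The sketch below (the helper `append_rows` is hypothetical) uses JSON in place of tensors. It also writes to a temporary file and swaps it in with `os.replace`, which is atomic on POSIX systems when both paths are on the same filesystem, so a crash mid-write cannot leave a half-written checkpoint. The same temporary-file-plus-`os.replace` pattern could wrap `torch.save`.

```python
import json
import os


def append_rows(rows, file_path):
    """Append-only save with the same contract as save_embeddings, using JSON lists."""
    existing = []
    if os.path.exists(file_path):
        with open(file_path) as f:
            existing = json.load(f)
    tmp_path = file_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(existing + rows, f)  # existing rows first, new rows after
    os.replace(tmp_path, file_path)  # atomic rename on the same filesystem
```

Calling `append_rows([[1, 2]], p)` and then `append_rows([[3, 4], [5, 6]], p)` leaves three rows on disk; passing rows that were read back from `p` would store them a second time.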

## Step 5: Define Main Embedding Extraction Function

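This function does the main work. It processes the images in batches of `batch_size`, skips any file that cannot be opened, and appends the embeddings to the checkpoint file every `save_every` batches so that an interrupted run can resume: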

```python
def get_image_embeddings(image_paths: list, processor, model, device, batch_size=32, checkpoint_path='embeddings.pt', save_every=100):
    # Resume after the rows already stored in the checkpoint. Every image gets
    # exactly one row (an all-zero row if it could not be loaded), so the row
    # count equals the number of images already processed.
    existing = load_embeddings(checkpoint_path)
    start_index = len(existing) if existing is not None else 0
    # Buffer only *new* embeddings: save_embeddings appends to the file, so
    # re-adding the loaded rows here would store them twice.
    embeddings = []
    embedding_dim = model.config.projection_dim

    skipped_images = []

    for i in tqdm(range(start_index, len(image_paths), batch_size), desc="Processing Batches"):
        batch_paths = image_paths[i:i+batch_size]
        batch_images = []
        valid_positions = []  # positions within the batch that loaded successfully

        for pos, path in enumerate(batch_paths):
            try:
                with Image.open(path) as image:
                    batch_images.append(image.convert("RGB"))  # decodes the file fully
                valid_positions.append(pos)
            except Exception as e:
                print(f"Skipping image '{path}' due to error: {e}")
                skipped_images.append((path, str(e)))

        # Zero placeholder rows keep row k aligned with image_paths[k]
        batch_embeddings = torch.zeros(len(batch_paths), embedding_dim)
        if batch_images:
            with torch.no_grad():
                inputs = processor(images=batch_images, return_tensors="pt").to(device)
                features = model.get_image_features(**inputs)
            batch_embeddings[valid_positions] = features.cpu().float()
        embeddings.append(batch_embeddings)

        if (i // batch_size + 1) % save_every == 0:
            save_embeddings(torch.cat(embeddings, dim=0), checkpoint_path)
            embeddings = []

    if embeddings:
        save_embeddings(torch.cat(embeddings, dim=0), checkpoint_path)

    # Save skipped images information (overwritten on every run)
    if skipped_images:
        with open('skipped_images.txt', 'w') as f:
            for path, error in skipped_images:
                f.write(f"{path}: {error}\n")

    return load_embeddings(checkpoint_path)
```
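
The loop's resume and checkpoint arithmetic can be traced without loading any images. The hypothetical `plan_batches` helper below reproduces only the index calculations of `get_image_embeddings`, assuming every image loads successfully, and flags the batches after which the checkpoint file is up to date. With the defaults (`batch_size=32`, `save_every=100`) that happens every 3,200 images.

```python
def plan_batches(num_images, start_index, batch_size=32, save_every=100):
    """List (batch_start, batch_end, checkpoint_after) for each batch the loop visits.

    Mirrors the loop arithmetic: batches start at start_index and step by
    batch_size; the checkpoint is written after every batch whose counter
    (i // batch_size + 1) is a multiple of save_every, and the final save
    after the loop covers the last batch.
    """
    starts = list(range(start_index, num_images, batch_size))
    plan = []
    for n, i in enumerate(starts):
        end = min(i + batch_size, num_images)
        is_last = n == len(starts) - 1
        checkpoint = (i // batch_size + 1) % save_every == 0 or is_last
        plan.append((i, end, checkpoint))
    return plan
```

For example, `plan_batches(100, 64, batch_size=32, save_every=2)` shows that a run resumed after 64 saved rows starts exactly at image 64.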

## Step 6: Run the Embedding Extraction

Now we'll run our function to extract embeddings from all our images:

```python
# Main execution
checkpoint_path = 'embeddings.pt'
embeddings = get_image_embeddings(image_paths, processor, model, device, checkpoint_path=checkpoint_path)
print(f"Embeddings extracted and saved. Shape: {embeddings.shape}")

# Save final embeddings
final_path = '/Volumes/Illustrated/code/multimodal/embeddings/OpenCLIPillustratedweeklynewsfull.pt'
torch.save(embeddings, final_path)
print(f"Final embeddings saved to {final_path}")
```

## Restarting Interrupted Embedding Extraction

If the embedding extraction process is interrupted (e.g., due to a power outage, system crash, or accidental notebook shutdown), you can easily restart it. The script is designed to resume from where it left off, thanks to the checkpoint system we've implemented. Here's how to restart the process:

1. **Ensure all cells are executed**: Make sure all the previous cells in this notebook have been executed, including the import statements, function definitions, and model loading.
2. **Check the checkpoint file**: Verify that the `embeddings.pt` file (or whatever name you've set for `checkpoint_path`) exists in your working directory. This file contains the embeddings that were successfully processed before the interruption.
3. **Run the main execution cell again**: Re-run the Step 6 cell. The `get_image_embeddings` function will load the existing embeddings from the checkpoint file, use their count to determine how many images have already been processed, and resume with the next unprocessed image. Because the resume point is only a row count, the image list must be identical to the one used before the interruption; adding or removing images in between would misalign the rows.
4. **Check for skipped images**: After the process completes, check whether a `skipped_images.txt` file was created. It lists the images that couldn't be processed, so you can investigate them and retry them later. The file is overwritten on each run, so it covers only the images skipped during the most recent run.
 
By following these steps, you can resume the embedding extraction after an interruption. Only the batches processed since the last checkpoint (at most `save_every` batches) need to be redone; everything saved before that is kept.
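
To retry the files listed in `skipped_images.txt`, the lines can be parsed back into `(path, error)` pairs. A minimal sketch (the helper `read_skipped` is hypothetical), assuming one entry per line in the `path: error` format written by `get_image_embeddings` and that no image path itself contains `": "`:

```python
def read_skipped(file_path="skipped_images.txt"):
    """Parse lines written as f"{path}: {error}" back into (path, error) pairs.

    Splits on the first ": " only, because error messages often contain
    ": " themselves; this assumes the image paths never do.
    """
    skipped = []
    with open(file_path) as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                continue  # ignore blank lines
            path, _, error = line.partition(": ")
            skipped.append((path, error))
    return skipped
```

Passing `[path for path, _ in read_skipped()]` to `get_image_embeddings` with a different `checkpoint_path` keeps the retried embeddings out of the main checkpoint, which `save_embeddings` would otherwise append to.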

## Conclusion

This notebook demonstrates how to extract image embeddings using the CLIP model. The process involves loading images, processing them in batches, and saving the resulting embeddings. This can be useful for various downstream tasks such as image similarity search, clustering, or as input to other machine learning models.
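
As a sketch of the similarity-search use case, the function below ranks stored embeddings by cosine similarity to a query vector. It is plain Python for illustration, and the name `cosine_top_k` is hypothetical. `get_image_features` returns unnormalized vectors, so with the real tensors you would typically L2-normalize them, for example with `torch.nn.functional.normalize(embeddings, dim=-1)`, and score all rows with a single matrix multiplication.

```python
import math


def cosine_top_k(query, vectors, k=5):
    """Return (index, cosine similarity) for the k rows of `vectors` closest to `query`.

    Rows with zero norm are skipped, because cosine similarity is undefined for them.
    """
    q_norm = math.sqrt(sum(x * x for x in query))
    scores = []
    for idx, vec in enumerate(vectors):
        v_norm = math.sqrt(sum(x * x for x in vec))
        if q_norm == 0 or v_norm == 0:
            continue
        dot = sum(a * b for a, b in zip(query, vec))
        scores.append((idx, dot / (q_norm * v_norm)))
    scores.sort(key=lambda pair: pair[1], reverse=True)  # most similar first
    return scores[:k]
```

Provided no images were skipped, a returned row index `idx` corresponds to `image_paths[idx]`.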

Remember to check the `skipped_images.txt` file (if it was created) to see if any images were skipped during processing.