In [None]:
# Install dependencies
# !pip install weaviate-client sentence-transformers pillow numpy tqdm

### Initial Setup

In [3]:
# Import libraries
import base64
import io
import json
import os
import time

import numpy as np
import weaviate
from PIL import Image
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from weaviate.classes.config import Configure, DataType, Multi2VecField, Property
from weaviate.util import generate_uuid5

In [None]:
# Load the CLIP checkpoint once; import_data_manual reuses this single
# instance to encode both images and their captions.
clip_model = SentenceTransformer(model_name_or_path='clip-ViT-B-32')

### Utility Functions

In [4]:
def read_image_data(image_path):
    """Read an image file once and return it in two forms.

    The file is read a single time per call and both return values are derived
    from the same in-memory bytes; nothing is cached between calls.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        A ``(base64_data, pil_img)`` tuple:
        - ``base64_data``: the file's raw bytes, base64-encoded, as a ``str``
          (the form written to the ``image`` BLOB property by the importers).
        - ``pil_img``: a ``PIL.Image.Image`` opened over an in-memory copy of
          the bytes. The caller owns it and should ``close()`` it when done.

    Raises:
        OSError: if the file cannot be opened or read.
        PIL.UnidentifiedImageError: if the bytes are not a recognised image
            format (raised by ``Image.open``).
    """
    with open(image_path, "rb") as img_file:
        raw_data = img_file.read()
    # Decode after the ``with`` so the OS file handle is released as soon as
    # the bytes are in memory. Image.open is lazy, but since it reads from a
    # BytesIO rather than the path, the image remains decodable after the
    # file itself has been closed.
    base64_data = base64.b64encode(raw_data).decode("utf-8")
    pil_img = Image.open(io.BytesIO(raw_data))
    return base64_data, pil_img

def process_metadata(metadata, max_chars=300):
    """Collapse a record's captions into one bounded string.

    Args:
        metadata: Mapping that may contain a ``"captions"`` entry holding a
            list of caption strings.
        max_chars: Length cap for the joined text; anything beyond it is cut
            off at exactly this many characters (possibly mid-word).

    Returns:
        The captions joined by single spaces and truncated to ``max_chars``,
        or the placeholder ``"No captions available"`` when the entry is
        missing, ``None`` or empty.
    """
    caption_items = metadata.get("captions", [])
    if not caption_items:
        return "No captions available"
    joined_text = " ".join(caption_items)
    return joined_text[:max_chars]

### Schema Creation

In [6]:
def create_flickr_schema_multi2vec(client):
    """Create the ``Flickr30k_multi2vec`` collection (server-side CLIP vectors).

    The ``image_vector`` named vector is configured with the multi2vec-clip
    vectorizer over the ``image`` BLOB (weight 0.7) and the ``captions`` text
    (weight 0.3). If the collection already exists it is left untouched — its
    configuration is not compared or updated — and a notice is printed.

    Args:
        client: A connected Weaviate v4 client.
    """
    collection_name = "Flickr30k_multi2vec"
    if client.collections.exists(collection_name):
        print("Flickr30k_multi2vec exists, skipping creation")
        return

    schema_properties = [
        Property(name="image", data_type=DataType.BLOB),
        Property(name="image_id", data_type=DataType.TEXT),
        Property(name="captions", data_type=DataType.TEXT),
    ]
    clip_vector = Configure.NamedVectors.multi2vec_clip(
        name="image_vector",
        image_fields=[Multi2VecField(name="image", weight=0.7)],
        text_fields=[Multi2VecField(name="captions", weight=0.3)],
    )
    client.collections.create(
        collection_name,
        properties=schema_properties,
        vectorizer_config=[clip_vector],
    )
    print("Created Flickr30k_multi2vec schema")

def create_flickr_schema_manual(client):
    """Create the ``Flickr30k_manual`` collection (client-supplied vectors).

    The ``image_vector`` named vector has no server-side vectorizer; vectors
    are provided at insert time (import_data_manual does this). An existing
    collection is skipped without any configuration check, with a notice.

    Args:
        client: A connected Weaviate v4 client.
    """
    target = "Flickr30k_manual"
    already_present = client.collections.exists(target)
    if already_present:
        print("Flickr30k_manual exists, skipping creation")
        return

    client.collections.create(
        target,
        properties=[
            Property(name=field_name, data_type=field_type)
            for field_name, field_type in (
                ("image", DataType.BLOB),
                ("image_id", DataType.TEXT),
                ("captions", DataType.TEXT),
            )
        ],
        vectorizer_config=[Configure.NamedVectors.none(name="image_vector")],
    )
    print("Created Flickr30k_manual schema")

### Import Data 

In [10]:
def import_data_multi2vec(client, data_dir, batch_size=100):
    """Import the Flickr30k sample into ``Flickr30k_multi2vec``.

    For each ``*.json`` file in ``<data_dir>/metadata`` the image
    ``<data_dir>/images/image_<image_id>.jpg`` is read and sent as a base64
    BLOB together with the concatenated captions. No vector is supplied; the
    collection's multi2vec-clip vectorizer computes it server-side.

    Each object gets a deterministic UUID derived from its ``image_id``, so
    re-running the import overwrites the same objects instead of appending
    duplicates. Objects inserted earlier with random UUIDs are not removed.

    Errors are reported, not raised:
      * local failures while building an object (bad JSON, missing image, ...)
        are collected per metadata file;
      * objects the server rejected are read from
        ``collection.batch.failed_objects`` after the batch is flushed.
    The count and the first five messages of each kind are printed.

    Args:
        client: A connected Weaviate v4 client.
        data_dir: Directory containing ``images/`` and ``metadata/``.
        batch_size: Number of objects per batch request.
    """
    collection = client.collections.get("Flickr30k_multi2vec")
    images_dir = os.path.join(data_dir, "images")
    metadata_dir = os.path.join(data_dir, "metadata")
    # Sorted so the import order is the same on every run and platform.
    metadata_files = sorted(f for f in os.listdir(metadata_dir) if f.endswith('.json'))
    print(f"Importing {len(metadata_files)} images for multi2vec")

    batch_errors = []
    with collection.batch.fixed_size(batch_size=batch_size) as batch:
        for metadata_file in tqdm(metadata_files):
            try:
                with open(os.path.join(metadata_dir, metadata_file), 'r') as f:
                    metadata = json.load(f)

                # image_id goes into a TEXT property; str() guards against
                # numeric ids in the JSON, which a TEXT property may reject.
                image_id = str(metadata["image_id"])
                image_path = os.path.join(images_dir, f"image_{image_id}.jpg")
                image_data, pil_img = read_image_data(image_path)
                # Only the base64 payload is needed here; release the image.
                pil_img.close()
                captions = process_metadata(metadata)

                batch.add_object(
                    properties={
                        "image": image_data,
                        "image_id": image_id,
                        "captions": captions
                    },
                    # Deterministic ID => re-imports upsert instead of duplicating.
                    uuid=generate_uuid5(image_id),
                )
            except Exception as e:
                batch_errors.append(f"Error processing {metadata_file}: {str(e)}")

    if batch_errors:
        print(f"Encountered {len(batch_errors)} errors during batch import")
        for err in batch_errors[:5]:  # Print first 5 errors
            print(err)
    # Server-side rejections (e.g. vectorizer failures) are not raised by the
    # batch context manager; they are only exposed here after it has exited.
    failed_objects = collection.batch.failed_objects
    if failed_objects:
        print(f"Weaviate rejected {len(failed_objects)} objects")
        for failed in failed_objects[:5]:
            print(failed.message)
    print("Finished multi2vec import")

def import_data_manual(client, data_dir, batch_size=100):
    """Import the Flickr30k sample into ``Flickr30k_manual`` with local CLIP vectors.

    For each ``*.json`` file in ``<data_dir>/metadata`` the image
    ``<data_dir>/images/image_<image_id>.jpg`` is read and embedded with the
    module-level ``clip_model``. Each caption is embedded separately and the
    caption embeddings are averaged. The stored ``image_vector`` is
    ``0.7 * image + 0.3 * mean(captions)``, or ``0.7 * image`` when there are
    no captions.

    Each object gets a deterministic UUID derived from its ``image_id``, so
    re-running the import overwrites the same objects instead of appending
    duplicates. Objects inserted earlier with random UUIDs are not removed.

    Errors are reported, not raised:
      * local failures (bad JSON, missing image, encoding errors) are
        collected per metadata file;
      * objects the server rejected are read from
        ``collection.batch.failed_objects`` after the batch is flushed.
    The count and the first five messages of each kind are printed.

    Args:
        client: A connected Weaviate v4 client.
        data_dir: Directory containing ``images/`` and ``metadata/``.
        batch_size: Number of objects per batch request.
    """
    collection = client.collections.get("Flickr30k_manual")
    images_dir = os.path.join(data_dir, "images")
    metadata_dir = os.path.join(data_dir, "metadata")
    # Sorted so the import order is the same on every run and platform.
    metadata_files = sorted(f for f in os.listdir(metadata_dir) if f.endswith('.json'))
    print(f"Importing {len(metadata_files)} images for manual")

    batch_errors = []
    with collection.batch.fixed_size(batch_size=batch_size) as batch:
        for metadata_file in tqdm(metadata_files):
            try:
                with open(os.path.join(metadata_dir, metadata_file), 'r') as f:
                    metadata = json.load(f)

                # image_id goes into a TEXT property; str() guards against
                # numeric ids in the JSON, which a TEXT property may reject.
                image_id = str(metadata["image_id"])
                image_path = os.path.join(images_dir, f"image_{image_id}.jpg")
                image_data, img = read_image_data(image_path)
                captions = process_metadata(metadata)

                try:
                    image_embedding = clip_model.encode(img, convert_to_numpy=True)
                finally:
                    # Close even when encoding raises, so the handle never leaks.
                    img.close()

                # Same 0.7 image / 0.3 text weights as the multi2vec collection.
                # Unlike multi2vec (which vectorizes the truncated concatenation
                # from process_metadata), every caption is embedded on its own
                # and the embeddings are averaged.
                # NOTE(review): embeddings are not L2-normalized before mixing,
                # so the effective weighting depends on their raw norms —
                # confirm whether normalize_embeddings=True is intended.
                caption_list = metadata.get("captions", [])
                if caption_list:
                    text_embedding = clip_model.encode(caption_list, convert_to_numpy=True).mean(axis=0)
                else:
                    text_embedding = np.zeros_like(image_embedding)
                combined_embedding = 0.7 * image_embedding + 0.3 * text_embedding

                batch.add_object(
                    properties={
                        "image": image_data,
                        "image_id": image_id,
                        "captions": captions
                    },
                    vector={"image_vector": combined_embedding.tolist()},
                    # Deterministic ID => re-imports upsert instead of duplicating.
                    uuid=generate_uuid5(image_id),
                )
            except Exception as e:
                batch_errors.append(f"Error processing {metadata_file}: {str(e)}")

    if batch_errors:
        print(f"Encountered {len(batch_errors)} errors during batch import")
        for err in batch_errors[:5]:
            print(err)
    # Server-side rejections are not raised by the batch context manager;
    # they are only exposed here after it has exited.
    failed_objects = collection.batch.failed_objects
    if failed_objects:
        print(f"Weaviate rejected {len(failed_objects)} objects")
        for failed in failed_objects[:5]:
            print(failed.message)
    print("Finished manual import")

In [11]:

def main(data_dir="flickr30k_sample"):
    """Create both Flickr30k collections and import the sample data into each.

    Connects with ``weaviate.connect_to_local()``. The connection is closed
    when this function exits, including when schema creation or an import
    raises.

    Args:
        data_dir: Directory containing the ``images/`` and ``metadata/``
            subdirectories. Defaults to ``"flickr30k_sample"``.
    """
    # The v4 client is a context manager: close() runs even if a step fails,
    # so a failed run does not leave the connection open in the kernel.
    with weaviate.connect_to_local() as client:
        create_flickr_schema_manual(client)
        create_flickr_schema_multi2vec(client)

        import_data_manual(client, data_dir)
        import_data_multi2vec(client, data_dir)


if __name__ == "__main__":
    main()

Flickr30k_manual exists, skipping creation
Flickr30k_multi2vec exists, skipping creation
Importing 100 images for manual


100%|██████████| 100/100 [00:13<00:00,  7.33it/s]


Finished manual import
Importing 100 images for multi2vec


100%|██████████| 100/100 [00:00<00:00, 2012.27it/s]


Finished multi2vec import
