In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import numpy as np
import os


In [2]:

# Load the pre-trained ResNet-50 model
model = models.resnet50(pretrained=True)
model = nn.Sequential(*list(model.children())[:-1])  # Remove the classification layer

# Define preprocessing transforms
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Function to get image embeddings
def get_image_embeddings(image_path):
    img = Image.open(image_path).convert('RGB')
    img = preprocess(img)
    img = img.unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        embeddings = model(img)
    return embeddings.squeeze().numpy()



In [3]:
from neo4j import GraphDatabase

# Connect to the Neo4j database
uri = "bolt://localhost:7687"
username = "neo4j"
password = "password"

class Neo4jDatabase:
    def __init__(self):
        self._driver = GraphDatabase.driver(uri, auth=(username, password))

    def close(self):
        self._driver.close()

    def create_image_node(self, embeddings,path):
        with self._driver.session() as session:
            session.write_transaction(self._create_image_node, embeddings,path)

    @staticmethod
    def _create_image_node(tx, embeddings,path):
        query = (
            "CREATE (img:Image {embeddings: $embeddings,path: $path})"
        )
        tx.run(query, embeddings=embeddings)

# # Usage
# image_path = 'path_to_your_image.jpg'
# embeddings = get_image_embeddings(image_path)

# neo4j_db = Neo4jDatabase()
# neo4j_db.create_image_node(embeddings)


In [4]:
image_paths = os.listdir('data/images')

# filter only jpg
image_paths = [image_path for image_path in image_paths if image_path.endswith('.jpg')]

In [5]:
image_paths = [os.path.join('data/images', path) for path in image_paths]

In [6]:

neo4j_db = Neo4jDatabase()
# Iterate through image paths and insert into Neo4j
for image_path in image_paths:
    embeddings = get_image_embeddings(image_path)
    neo4j_db.create_image_node(embeddings,image_path)

  session.write_transaction(self._create_image_node, embeddings)
  session.write_transaction(self._create_image_node, embeddings)
  session.write_transaction(self._create_image_node, embeddings)
  session.write_transaction(self._create_image_node, embeddings)
  session.write_transaction(self._create_image_node, embeddings)
  session.write_transaction(self._create_image_node, embeddings)
  session.write_transaction(self._create_image_node, embeddings)
  session.write_transaction(self._create_image_node, embeddings)
  session.write_transaction(self._create_image_node, embeddings)
  session.write_transaction(self._create_image_node, embeddings)
  session.write_transaction(self._create_image_node, embeddings)
  session.write_transaction(self._create_image_node, embeddings)
  session.write_transaction(self._create_image_node, embeddings)
  session.write_transaction(self._create_image_node, embeddings)
  session.write_transaction(self._create_image_node, embeddings)
  session.write_transacti