In [1]:
import ipfshttpclient
import os
import json
import sys
import tensorflow as tf
import argparse
from PIL import Image
import numpy as np
import hashlib
import faiss
from tqdm import tqdm
from io import BytesIO

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import matplotlib.pyplot as plt
sys.path.append("..")

from demo.image_similarity_keras.model import SiameseModel

In [2]:
# Connect to the IPFS HTTP API at the library's default address.
# Every request made through `client` gives up after this many seconds.
IPFS_TIMEOUT_SECONDS = 300
client = ipfshttpclient.connect(timeout=IPFS_TIMEOUT_SECONDS)

In [3]:
# Ask TensorFlow to grow GPU memory on demand instead of reserving it all up
# front. This has to happen before the GPU is first used.
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    try:
        tf.config.experimental.set_memory_growth(physical_devices[0], True)
    except (ValueError, RuntimeError) as err:
        # Invalid device or cannot modify virtual devices once initialized.
        # Best effort only: report the problem and carry on with defaults.
        print(f"Could not enable GPU memory growth: {err}")
else:
    # Previously this case was hidden by a bare `except` around an IndexError.
    print("No GPU visible to TensorFlow; memory growth not configured.")

In [4]:
# Directory of the trained model; "configs.json" and the "weights" checkpoint
# are read from it further down. Relative to the current working directory.
model_path = "../demo/models/ConvNext_Large_64b_100ep_final"
# Path of the JSON file holding the augmentation settings.
augmentation_config = "../demo/configs/default_augmentation.json"

In [5]:
# Read the model settings stored next to the checkpoint
model_config_file = os.path.join(model_path, "configs.json")
with open(model_config_file, "r") as config_fp:
    model_config = json.load(config_fp)

# Attribute-style view of the same settings
model_config_ns = argparse.Namespace(**model_config)

# Read the augmentation settings.
# NOTE: `augmentation_config` holds the file path before this statement and
# the parsed JSON content afterwards, so this cell cannot simply be re-run
# without re-running the cell that defines the path first.
with open(augmentation_config, "r") as augmentation_fp:
    augmentation_config = json.load(augmentation_fp)

In [6]:
# Namespace wrapper around the model settings (attribute access instead of keys)
model_config_ns = argparse.Namespace(**model_config)

# Input resolution: taken from the model settings when present,
# otherwise the default below is used
default_image_size = 224
if 'image_size' in model_config:
    image_size = model_config['image_size']
else:
    image_size = default_image_size

In [7]:
# Create the Siamese model from the saved settings
model = SiameseModel(**model_config)

# Build and compile the underlying network
# (the meaning of the `False` flag is defined by SiameseModel.build - TODO confirm)
model.build(False)

# Restore the trained weights from the checkpoint stored in the model directory
weights_prefix = os.path.join(model_path, "weights")
model.model.load_weights(weights_prefix)

2023-09-11 06:12:10.741919: I tensorflow/core/platform/cpu_feature_guard.cc:152] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-09-11 06:12:11.206244: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 19406 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:3b:00.0, compute capability: 8.6


Model: "siamese_ConvNext_Large"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 ConvNext_Large (KerasLayer)  (None, 1536)             196230336 
                                                                 
 dense (Dense)               (None, 512)               786944    
                                                                 
 dense_1 (Dense)             (None, 256)               131328    
                                                                 
 out_emb (Dense)             (None, 128)               32896     
                                                                 
 l2_norm (Lambda)            (None, 128)               0         
                                                                 
Total params: 197,181,504
Trainable params: 951,168
Non-trainable params: 196,230,336
_________________________________________________________________


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f4da8318a30>

In [8]:
def cid_to_int(cid, modulus=2**31 - 1):
    """Map a CID string to a deterministic, non-negative integer ID.

    The ID is the SHA-256 digest of the UTF-8 encoded CID, read as a
    big-endian integer and reduced modulo ``modulus``.

    Args:
        cid: Content identifier as a string.
        modulus: Size of the ID space; results lie in ``[0, modulus - 1]``.
            The default (``2**31 - 1``) keeps existing IDs unchanged.

    Returns:
        int: The ID derived from ``cid``.

    Raises:
        ValueError: If ``modulus`` is not a positive integer.

    Note:
        Different CIDs can map to the same ID. With the default 31-bit range
        the birthday bound puts the chance of at least one clash at roughly
        19% for 30,000 CIDs; a larger ``modulus`` (for example ``2**63 - 1``)
        makes this negligible, but produces IDs that are not compatible with
        ones created using the default.
    """
    if modulus <= 0:
        raise ValueError(f"modulus must be a positive integer, got {modulus!r}")
    digest = hashlib.sha256(cid.encode()).digest()
    # Same value as int(hexdigest, 16), without the hex round-trip.
    return int.from_bytes(digest, "big") % modulus

In [13]:
# Exact L2 index wrapped in an ID map so vectors can be stored under custom
# integer IDs. 128 is the width of the model's `out_emb` / `l2_norm` output.
EMBEDDING_DIM = 128
flat_l2_index = faiss.IndexFlatL2(EMBEDDING_DIM)
index = faiss.IndexIDMap2(flat_l2_index)

In [20]:
def set_nft_vector_database(model, db_index, image_path, batch_size=16,
                            image_size=224, start_index=0):
    """Embed the images of one collection and add the vectors to a FAISS index.

    The file names and content hashes come from the IPFS MFS directory
    ``"/" + image_path`` (queried through the module-level ``client``); the
    pixel data is read from the local directory ``image_path``. Each image is
    converted to RGB, resized, divided by 255, embedded in batches with
    ``model.predict`` and stored under the ID ``cid_to_int(<file hash>)``.

    Args:
        model: Object whose ``predict(batch)`` returns one embedding per image.
        db_index: FAISS index that supports ``add_with_ids``.
        image_path: Directory of the collection, used both as local path and
            (prefixed with "/") as IPFS MFS path.
        batch_size: Number of images embedded per ``model.predict`` call.
        image_size: Target size for resizing; either a single number (square)
            or a ``(width, height)`` pair. Defaults to 224, as before.
        start_index: Number of leading directory entries to skip. Use it to
            resume an interrupted run: vectors are only added in complete
            batches, so pass the number of entries already in the index.

    Returns:
        None. ``db_index`` is modified in place.
    """
    import posixpath  # IPFS MFS paths use "/" regardless of the host OS

    ipfs_path = "/" + image_path
    # Guard against a missing/None entry list (e.g. for an empty directory).
    entries = client.files.ls(ipfs_path)['Entries'] or []
    entries = entries[start_index:]

    if np.ndim(image_size) == 0:
        target_size = (int(image_size), int(image_size))
    else:
        target_size = tuple(image_size)

    image_array_batch = []
    ids_batch = []

    def flush_batch():
        """Embed the pending images, add them to the index, reset the buffers."""
        image_embs_batch = model.predict(np.array(image_array_batch))
        # FAISS expects float32 vectors and int64 IDs; be explicit so the
        # result does not depend on the platform's default integer width.
        db_index.add_with_ids(
            np.asarray(image_embs_batch, dtype=np.float32),
            np.asarray(ids_batch, dtype=np.int64),
        )
        image_array_batch.clear()
        ids_batch.clear()

    for entry in tqdm(entries, desc="Set NFT vector database"):
        file_name = entry['Name']
        image_cid = client.files.stat(posixpath.join(ipfs_path, file_name))['Hash']

        # Close the file handle as soon as the pixels have been copied out.
        with Image.open(os.path.join(image_path, file_name)) as image_file:
            image_pil = image_file.convert("RGB").resize(target_size)
        image_array = tf.keras.preprocessing.image.img_to_array(image_pil) / 255.0

        image_array_batch.append(image_array)
        ids_batch.append(cid_to_int(image_cid))

        if len(image_array_batch) == batch_size:
            flush_batch()

    # Add whatever is left over from the last, incomplete batch.
    if image_array_batch:
        flush_batch()

In [29]:
def set_nft_vector_database_test(model, db_index, image_path, batch_size=16,
                                 image_size=224, start_index=7664):
    """Resume adding a collection's image embeddings to a FAISS index.

    Works on the directory entries from position ``start_index`` onwards. The
    file names and content hashes come from the IPFS MFS directory
    ``"/" + image_path`` (queried through the module-level ``client``); the
    pixel data is read from the local directory ``image_path``. Each image is
    converted to RGB, resized, divided by 255, embedded in batches with
    ``model.predict`` and stored under the ID ``cid_to_int(<file hash>)``.

    Args:
        model: Object whose ``predict(batch)`` returns one embedding per image.
        db_index: FAISS index that supports ``add_with_ids``.
        image_path: Directory of the collection, used both as local path and
            (prefixed with "/") as IPFS MFS path.
        batch_size: Number of images embedded per ``model.predict`` call.
        image_size: Target size for resizing; either a single number (square)
            or a ``(width, height)`` pair. Defaults to 224, as before.
        start_index: Number of leading directory entries to skip. Defaults to
            7664, the offset that used to be hard-coded in this function.

    Returns:
        None. ``db_index`` is modified in place.
    """
    import posixpath  # IPFS MFS paths use "/" regardless of the host OS

    ipfs_path = "/" + image_path
    # Guard against a missing/None entry list (e.g. for an empty directory).
    entries = client.files.ls(ipfs_path)['Entries'] or []
    entries = entries[start_index:]

    if np.ndim(image_size) == 0:
        target_size = (int(image_size), int(image_size))
    else:
        target_size = tuple(image_size)

    image_array_batch = []
    ids_batch = []

    def flush_batch():
        """Embed the pending images, add them to the index, reset the buffers."""
        image_embs_batch = model.predict(np.array(image_array_batch))
        # FAISS expects float32 vectors and int64 IDs; be explicit so the
        # result does not depend on the platform's default integer width.
        db_index.add_with_ids(
            np.asarray(image_embs_batch, dtype=np.float32),
            np.asarray(ids_batch, dtype=np.int64),
        )
        image_array_batch.clear()
        ids_batch.clear()

    for entry in tqdm(entries, desc="Set NFT vector database"):
        file_name = entry['Name']
        image_cid = client.files.stat(posixpath.join(ipfs_path, file_name))['Hash']

        # Close the file handle as soon as the pixels have been copied out.
        with Image.open(os.path.join(image_path, file_name)) as image_file:
            image_pil = image_file.convert("RGB").resize(target_size)
        image_array = tf.keras.preprocessing.image.img_to_array(image_pil) / 255.0

        image_array_batch.append(image_array)
        ids_batch.append(cid_to_int(image_cid))

        if len(image_array_batch) == batch_size:
            flush_batch()

    # Add whatever is left over from the last, incomplete batch.
    if image_array_batch:
        flush_batch()

In [15]:
# One directory per NFT collection. Each path is used as a local directory
# and, prefixed with "/", as the matching IPFS MFS directory.
azuki_path = "nft_images/azuki"
bayc_path = "nft_images/bayc"
cryptopunks_path = "nft_images/cryptopunks"

In [16]:
# Add the Azuki collection to the index.
set_nft_vector_database(model, index, azuki_path)

Set NFT vector database: 100%|████████████| 10000/10000 [19:26<00:00,  8.57it/s]


In [21]:
# Add the BAYC collection to the index.
set_nft_vector_database(model, index, bayc_path)

Set NFT vector database: 100%|████████████| 10000/10000 [31:23<00:00,  5.31it/s]


In [22]:
# Add the CryptoPunks collection to the index.
# NOTE: in the recorded run this call stopped with an IPFS read timeout after
# 7664 of 10000 images; the remaining entries are added by a follow-up cell.
set_nft_vector_database(model, index, cryptopunks_path)

Set NFT vector database:  77%|█████████▉   | 7664/10000 [14:36<04:27,  8.74it/s]


TimeoutError: ReadTimeout: HTTPConnectionPool(host='localhost', port=5001): Read timed out. (read timeout=300)

In [30]:
# Continue the CryptoPunks import from entry 7664, where the previous run
# was interrupted by the read timeout.
set_nft_vector_database_test(model, index, cryptopunks_path)

Set NFT vector database: 100%|██████████████| 2336/2336 [00:29<00:00, 79.24it/s]


In [31]:
# Sanity check: number of vectors stored in the index
# (three collections of 10,000 images each were processed above).
print(index.ntotal)

30000


In [32]:
# Save the populated index to 'base.index' (relative to the working directory).
faiss.write_index(index, 'base.index')