In [2]:
from lib.models.siamese import Siamese
import hydra
from lib.utils import load_config

ckpt_path = "/home/borth/sketch2shape/logs/train/runs/2023-12-02_14-39-32/checkpoints/last.ckpt"
cfg = load_config("train_siamese", overrides=["data=siamese_chair_large","data.drop_last=False"]) 
model = Siamese.load_from_checkpoint(ckpt_path)
datamodule = hydra.utils.instantiate(cfg.data)
datamodule.setup("all")

/home/borth/miniconda3/envs/pytorch3d/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['decoder'])`.
/home/borth/miniconda3/envs/pytorch3d/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'miner' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['miner'])`.
/home/borth/miniconda3/envs/pytorch3d/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.


In [3]:
from lightning import Trainer
from pytorch_metric_learning import testers
import torch
import faiss
import numpy as np
from torchmetrics.aggregation import MeanMetric

def get_all_vectors_from_faiss_index(self, index):
    num_vectors = index.ntotal
    all_vectors = np.empty((num_vectors, index.d), dtype=np.float32)
    batch_size = 1000
    for start in range(0, num_vectors, batch_size):
        end = min(start + batch_size, num_vectors)
        vectors_batch = index.reconstruct_n(start, end - start)
        all_vectors[start:end, :] = vectors_batch
    return all_vectors

class SiameseTester(Siamese):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        d = self.decoder.resnet18.fc.out_features
        self.image_index = faiss.IndexFlatL2(d)
        self.sketch_index = faiss.IndexFlatL2(d)
        
        self.index_mode = "image" # image, sketch, all
        self.query_mode = "sketch" # image, sketch, all

        self.mean_distance = MeanMetric()

    @property    
    def image(self):
        return get_all_vectors_from_faiss_index(self.train_image_index)

    @property    
    def sketch(self):
        return get_all_vectors_from_faiss_index(self.train_sketch_index)
    
    @property
    def index(self):
        return self.sketch_index if self.index_mode == "sketch" else self.image_index

    def training_step(self, batch, batch_idx):
        torch.set_grad_enabled(False)
        self.eval()
        index_emb = self.decoder(batch[self.index_mode])
        self.index.add(index_emb.detach().cpu().numpy())

    def on_training_end(self) -> None:
        pass

    def test_step(self, batch, batch_idx):
        query_emb = self.decoder(batch[self.query_mode])
        D, I = self.index.search(query_emb, k=10)
        self.mean_distance.update(D)
        self.log("test/mean_distance", self.mean_distance)

    def on_test_end(self) -> None:
        pass

tester = SiameseTester.load_from_checkpoint(ckpt_path)
trainer = Trainer(max_epochs=1)
trainer.fit(tester, train_dataloaders=[
    datamodule.train_dataloader(),
    datamodule.val_dataloader(),
])
trainer.test(tester, dataloaders=datamodule.val_dataloader())

AttributeError: 'ResNet' object has no attribute 'resnet18'

AssertionError: 

In [19]:
tester.train_image_index
tester.val_image_index

<faiss.swigfaiss_avx2.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x7fba1318a600> >

In [14]:


    

index = tester.train_image_index
index.add(tester.val_image_index)
tester._get_all_vectors_from_faiss_index(index).shape

AssertionError: 

In [12]:
import numpy as np
x = np.random.randn(32, 128)
tester.image_index.search(x, k=3)

AttributeError: 'tuple' object has no attribute 'shape'

In [15]:
x = np.random.randn(1, 128)
tester.image_index.search(x, k=20)

(array([[114.62485 , 118.6391  , 119.15996 , 119.19931 , 119.35205 ,
         119.6527  , 119.751816, 120.0435  , 120.20599 , 120.37848 ,
         120.38583 , 120.54114 , 120.581345, 120.650406, 120.82707 ,
         120.84895 , 120.90802 , 120.96532 , 120.97449 , 121.00253 ]],
       dtype=float32),
 array([[110760,  81641, 139215,  67727, 153218,  76287, 153231, 120576,
         127798, 101635,  94316, 112844, 149910,   1100,  19147,  17378,
          66755,  47478,  16482, 112841]]))

In [None]:
# precition@k

# recall@k



In [18]:
# Get the total number of vectors in the index
num_vectors = tester.image_index.ntotal

# Specify the range of IDs covering all vectors
ids_to_retrieve = list(range(num_vectors))

# Get all vectors from the index
all_vectors = tester.image_index.reconstruct(ids_to_retrieve)

TypeError: in method 'IndexFlat_reconstruct', argument 2 of type 'faiss::idx_t'

In [28]:
tester.image_index

<faiss.swigfaiss_avx2.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x7fe460aeda20> >

In [29]:
import faiss

# Assuming tester.image_index is your Faiss index

# Get the total number of vectors in the index
num_vectors = tester.image_index.ntotal

# Specify the range of IDs covering all vectors
ids_to_retrieve = list(range(num_vectors))

# Initialize an empty array to store all vectors
all_vectors = np.empty((num_vectors, tester.image_index.d), dtype=np.float32)

# Retrieve vectors in batches (you can adjust the batch size as needed)
batch_size = 1000

for start in range(0, num_vectors, batch_size):
    end = min(start + batch_size, num_vectors)
    
    # Retrieve vectors in the current batch
    vectors_batch = tester.image_index.reconstruct_n(start, end - start)
    
    # Store the vectors in the result array
    all_vectors[start:end, :] = vectors_batch

In [37]:


image_embs = get_all_vectors_from_faiss_index(tester.image_index)
sketch_embs = get_all_vectors_from_faiss_index(tester.sketch_index)

In [40]:
out = tester.image_index.search(image_embs, k=3)

(array([[1.5258789e-05, 3.5209999e+00, 6.1294518e+00],
        [2.2888184e-05, 6.1474991e+00, 7.6835938e+00],
        [0.0000000e+00, 8.3300247e+00, 1.0208111e+01],
        ...,
        [0.0000000e+00, 1.0393028e+01, 1.0764172e+01],
        [0.0000000e+00, 1.0492321e+01, 1.1005440e+01],
        [0.0000000e+00, 9.2644119e+00, 9.6041336e+00]], dtype=float32),
 array([[     0,      4,     16],
        [     1,      3,     23],
        [     2,     10,  43946],
        ...,
        [204797, 189764, 184665],
        [204798, 204790, 204787],
        [204799,  91167, 204791]]))