In [None]:
cd ..

In [2]:
import numpy as np
from pathlib import Path
import logging
from latice.index.faiss_db import (
    FaissLatentVectorDatabase,
    FaissLatentVectorDatabaseConfig,
)
from latice.index.raw_dp_indexer import RawDiffractionPatternIndexer, RawIndexerConfig

# Layout of every log record: timestamp, logger name, severity, then the message.
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# Configure the root logger once for the whole notebook at INFO level.
logging.basicConfig(format=log_format, level=logging.INFO)

In [None]:
# --- Paths ---
# File that stores the raw-pattern FAISS index. It is derived from the current
# working directory instead of being hard-coded to one user's home folder, so
# the notebook can run on other machines. Kept as an absolute `str`, matching
# the type of the value that was used here before.
# NOTE(review): assumes the kernel's working directory is the repository root
# (the relative data/ paths below rely on the same assumption) - confirm.
raw_npz_path = str(Path.cwd() / "notebook" / "faiss_raw_index.npz")
pattern_path = Path("data/N=100_noised.npy")
angles_path = Path("data/anglefile_N=100.txt")

# --- Indexer settings (forwarded to RawIndexerConfig further down) ---
batch_size = 32
random_seed = 42
image_size = (128, 128)
top_n = 10
orientation_threshold = 3.0

# Vector dimension handed to the FAISS database: the product of the two
# image_size entries (presumably one value per pixel of a flattened pattern).
raw_dimension = image_size[0] * image_size[1]
print(f"Raw pattern dimension: {raw_dimension}")

In [None]:
# Describe the index (its .npz path and the vector dimension), then build the
# FAISS-backed database object from that description.
faiss_raw_db_config = FaissLatentVectorDatabaseConfig(
    dimension=raw_dimension,
    npz_path=raw_npz_path,
)
faiss_raw_db = FaissLatentVectorDatabase(config=faiss_raw_db_config)

In [None]:
# Collect every indexer setting from the configuration cell into one config object.
raw_indexer_config = RawIndexerConfig(
    db_path=raw_npz_path,  # same path the FAISS database was configured with
    pattern_path=pattern_path,
    angles_path=angles_path,
    image_size=image_size,
    batch_size=batch_size,
    random_seed=random_seed,
    top_n=top_n,
    orientation_threshold=orientation_threshold,
)

# Hand the already-constructed database to the indexer together with its config.
raw_indexer = RawDiffractionPatternIndexer(db=faiss_raw_db, config=raw_indexer_config)

In [None]:
# Build the raw-pattern index only when the database reports zero entries, so
# re-running this cell does not rebuild an index that is already populated.
existing_count = raw_indexer.db.get_count()
if existing_count == 0:
    print(f"Building raw FAISS index at {raw_indexer.db.npz_path}...")
    try:
        raw_indexer.build_dictionary()
    except Exception as e:
        # Report the failure, then re-raise it. Swallowing the error here would
        # let the following cells run against an empty index and fail (or return
        # meaningless results) far away from the real cause.
        print(f"Error building dictionary: {e}")
        raise
    print(f"Index built successfully with {raw_indexer.db.get_count()} raw patterns.")
else:
    print(f"Raw FAISS index already exists at {raw_indexer.db.npz_path} with {existing_count} patterns.")


In [None]:
# Number of patterns, taken from the start of the pattern file, to index as a
# quick smoke test.
n_query_patterns = 5
batch_patterns_np = np.load(pattern_path)[:n_query_patterns]

# Ask for the single best match per pattern (top_n=1). The orientation threshold
# reuses the constant from the configuration cell instead of repeating the
# literal 3.0, so the two values cannot drift apart.
orientation_results_batch = raw_indexer.index_patterns_batch(
    batch_patterns_np, top_n=1, orientation_threshold=orientation_threshold
)

In [None]:
# Bare last expression: Jupyter renders the batch results as this cell's output.
orientation_results_batch