In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
from pathlib import Path
import torch
from latice.index.faiss_db import (
    FaissLatentVectorDatabase,
    FaissLatentVectorDatabaseConfig,
)
from latice.model import VariationalAutoEncoderRawData
from latice.index.dp_indexer import DiffractionPatternIndexer, IndexerConfig


# Configuration

In [2]:
# --- Tunable settings for the whole notebook ---------------------------------
# Input/output locations: the FAISS index archive, the pattern array and the
# angle file handed to the indexer, and the VAE checkpoint.
npz_path = "faiss_index.npz"
pattern_path = Path("../data/N=100_noised.npy")
angles_path = Path("../data/anglefile_N=100.txt")
model_path = "../checkpoints/vae-best.pt"

# Used both as the FAISS vector dimension and as the indexer's latent_dim,
# so the two always agree.
dimension = 16

# Pick the compute device at runtime instead of hard-coding "mps", so the
# notebook also runs on machines without Apple's Metal backend. MPS is still
# preferred when present, which keeps the original behaviour on Apple silicon.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

batch_size = 32
random_seed = 42
image_size = (128, 128)

# Defaults forwarded to IndexerConfig.
top_n = 10
orientation_threshold = 3.0  # NOTE(review): unit not visible here (degrees?) — confirm

# Initialise the FAISS vector database

In [None]:
# Build the database settings first (index archive path and vector dimension),
# then construct the FAISS-backed latent vector database from them.
faiss_config = FaissLatentVectorDatabaseConfig(
    npz_path=npz_path,
    dimension=dimension,
)
faiss_db = FaissLatentVectorDatabase(config=faiss_config)

# Initialise the VAE model

In [None]:
# Create the VAE with its default constructor arguments and restore the
# trained weights from the checkpoint file.
model = VariationalAutoEncoderRawData()
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code on load. That is enough
# here because the loaded object is passed directly to load_state_dict.
checkpoint = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint)
# Put the model in evaluation mode; being the cell's last expression, whatever
# eval() returns is also shown as the cell output.
model.eval()

# Initialise the indexer

In [None]:
# Collect every indexer setting from the configuration cell into one object...
indexer_config = IndexerConfig(
    pattern_path=pattern_path,
    angles_path=angles_path,
    batch_size=batch_size,
    device=device,
    latent_dim=dimension,
    random_seed=random_seed,
    image_size=image_size,
    top_n=top_n,
    orientation_threshold=orientation_threshold,
)

# ...then wire the VAE, the FAISS database and that configuration together.
indexer = DiffractionPatternIndexer(model=model, db=faiss_db, config=indexer_config)

# Build the FAISS dictionary (index)

In [None]:
# Build the dictionary that later lookups are matched against.
# NOTE(review): presumably encodes the patterns at `pattern_path` with the VAE
# and stores the latent vectors in `faiss_db` — confirm against
# DiffractionPatternIndexer.build_dictionary.
indexer.build_dictionary()

# Index a single pattern

In [None]:
# Take the first (pattern, angles) pair from the indexer's dataset as a probe.
# NOTE(review): `_create_dataloader` is a private member accessed without a
# call; this only works if it is a property/attribute exposing `.dataset`.
# If it is a plain method it must be called first — confirm.
sample_pattern, angles = indexer._create_dataloader.dataset[0]
# Drop singleton dimensions before handing the pattern to the indexer.
sample_pattern = sample_pattern.squeeze()

# Request 20 candidates for this demo (more than the configured default), but
# reuse the threshold from the configuration cell rather than repeating the
# literal, so tuning it there also applies here.
orientation_result = indexer.index_pattern(
    pattern=sample_pattern, top_n=20, orientation_threshold=orientation_threshold
)

# Compare the stored angles with the 10 best orientations returned.
print(f"True angles: {angles}")
print(f"Best orientation: {orientation_result.get_top_n_orientations(10)}")
print(f"Success: {orientation_result.success}")


# Batch indexing

In [None]:
# Load the stored pattern array and keep only its first five entries.
batch_patterns = np.load(pattern_path)[:5]
# Index all five in a single call, using the indexer's configured defaults.
orientation_results = indexer.index_patterns_batch(batch_patterns)
# One summary line per pattern: position in the batch, mean orientation, success flag.
for i, result in enumerate(orientation_results):
    print(f"Pattern {i}: {result.mean_orientation} (success: {result.success})")