In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
from pathlib import Path
import torch
from latice.index.faiss_db import (
    FaissLatentVectorDatabase,
    FaissLatentVectorDatabaseConfig,
)
from latice.model import VariationalAutoEncoderRawData
from latice.index.dp_indexer import DiffractionPatternIndexer, IndexerConfig


# Configuration

In [2]:
# --- File locations (relative to the notebook's working directory) ---------
npz_path = "faiss_index.npz"  # location handed to the FAISS database config
pattern_path = Path("../data/N=100_noised.npy")
angles_path = Path("../data/anglefile_N=100.txt")
model_path = "../checkpoints/vae-best.pt"

# --- Compute device ---------------------------------------------------------
# Prefer Apple's MPS backend (the value this notebook was written with), but
# fall back to CUDA and then CPU so the notebook still runs on machines
# without MPS instead of failing when the checkpoint is loaded.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on old torch builds
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- Model / indexing settings ---------------------------------------------
dimension = 16  # used both as the FAISS dimension and as the indexer's latent_dim
batch_size = 32
random_seed = 42
image_size = (128, 128)
top_n = 10
orientation_threshold = 3.0

# Initialise the FAISS vector database

In [3]:
# Describe the vector store first (backing file and vector dimension), then
# create the database from that description. Per the recorded log output, a
# new index is created when none exists yet at `npz_path`.
faiss_config = FaissLatentVectorDatabaseConfig(
    npz_path=npz_path,
    dimension=dimension,
)
faiss_db = FaissLatentVectorDatabase(config=faiss_config)

2025-04-24 22:58:21,564 - latice.index.faiss_db - INFO - No existing index found at faiss_index.npz. Creating a new one.


# Initialise the VAE model

In [4]:
# Instantiate the VAE with its default constructor arguments and restore the
# trained weights from `model_path`.
model = VariationalAutoEncoderRawData()

# The checkpoint is passed straight to `load_state_dict`, i.e. it is a plain
# state dict, so restrict unpickling to tensors/containers with
# `weights_only=True` rather than allowing arbitrary pickled objects.
checkpoint = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint)

# NOTE(review): `map_location` only controls where the checkpoint tensors are
# loaded; `load_state_dict` copies them into the model's existing parameters,
# so the model itself is not moved here. Presumably DiffractionPatternIndexer
# moves the model to `device` — confirm.

# Switch to inference mode; as the last expression this also displays the
# module structure.
model.eval()

VariationalAutoEncoderRawData(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.02)
    )
    (1): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.02)
    )
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.02)
    )
    (4): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(64, eps=1e-

# Initialise the indexer

In [None]:
# Collect every indexer setting from the configuration cell into one config
# object, then wire it together with the loaded model and the FAISS database.
indexer_config = IndexerConfig(
    pattern_path=pattern_path,
    angles_path=angles_path,
    batch_size=batch_size,
    device=device,
    latent_dim=dimension,
    random_seed=random_seed,
    image_size=image_size,
    top_n=top_n,
    orientation_threshold=orientation_threshold,
)
indexer = DiffractionPatternIndexer(model=model, db=faiss_db, config=indexer_config)

# Build the FAISS dictionary (index)

In [None]:
# Populate the dictionary/index used by the lookups in the cells below.
# NOTE(review): presumably this encodes the patterns at `pattern_path` with the
# VAE and stores the latent vectors in `faiss_db` — confirm against
# DiffractionPatternIndexer.build_dictionary.
indexer.build_dictionary()

# Index a single pattern

In [None]:
# Take the first (pattern, angles) pair from the indexer's dataset and index
# that one pattern.
# NOTE(review): `_create_dataloader` is a private member of the indexer and is
# accessed here without being called; confirm it is a property/attribute that
# exposes `.dataset` rather than a method that must be invoked first.
sample_pattern, angles = indexer._create_dataloader.dataset[0]
sample_pattern = sample_pattern.squeeze()  # drop any size-1 dimensions

orientation_result = indexer.index_pattern(
    pattern=sample_pattern,
    top_n=20,  # differs from the configured `top_n`; kept as written
    # Reuse the configured threshold instead of repeating the literal 3.0, so
    # this cell cannot drift from the configuration cell.
    orientation_threshold=orientation_threshold,
)

print(f"True angles: {angles}")
print(f"Best orientation: {orientation_result.get_top_n_orientations(10)}")
print(f"Success: {orientation_result.success}")


# Batch indexing

In [None]:
# Load the pattern file, keep only its first five entries, and index them in a
# single batched call.
all_patterns = np.load(pattern_path)
batch_patterns = all_patterns[:5]
orientation_results = indexer.index_patterns_batch(batch_patterns)

# Report one line per pattern: its mean orientation and the success flag.
for i, result in enumerate(orientation_results):
    summary = f"Pattern {i}: {result.mean_orientation} (success: {result.success})"
    print(summary)