In [1]:
%load_ext autoreload
%autoreload 2
%cd ..

import numpy as np
from pathlib import Path
import torch
from latice.index.faiss_db import (
    FaissLatentVectorDatabase,
    FaissLatentVectorDatabaseConfig,
)
from latice.model import VariationalAutoEncoderRawData
from latice.index.dp_indexer import DiffractionPatternIndexer, IndexerConfig


/Users/andrewtung/Documents/Github/latice


# Configuration

In [3]:
npz_path = "faiss_index.npz"
pattern_path = Path("data/N=100_noised.npy")
angles_path = Path("data/anglefile_N=100.txt")
dimension = 16
device = "mps"
model_path = "checkpoints/vae-best.pt"
batch_size = 32
random_seed = 42
image_size = (128, 128)
top_n = 1
orientation_threshold = 1.0

# Initialise the FAISS vector database

In [4]:
faiss_db = FaissLatentVectorDatabase(
    config=FaissLatentVectorDatabaseConfig(npz_path=npz_path, dimension=dimension)
)

2025-04-27 00:03:43,473 - latice.index.faiss_db - INFO - Index file (faiss_index.index) or orientations file (faiss_index.orient.npz) not found. Creating a new index.


# Initialise the VAE model

In [5]:
model = VariationalAutoEncoderRawData()
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint)
model.eval()

VariationalAutoEncoderRawData(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.02)
    )
    (1): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.02)
    )
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.02)
    )
    (4): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(64, eps=1e-

# Initialise the indexer

In [6]:
indexer = DiffractionPatternIndexer(
    model=model,
    db=faiss_db,
    config=IndexerConfig(
        pattern_path=pattern_path,
        angles_path=angles_path,
        batch_size=batch_size,
        device=device,
        latent_dim=dimension,
        random_seed=random_seed,
        image_size=image_size,
        top_n=top_n,
        orientation_threshold=orientation_threshold,
    ),
)

2025-04-27 00:03:50,381 - latice.index.dp_indexer - INFO - Using device: mps


# Build the FAISS dictionary (index)

In [7]:
indexer.build_dictionary()

2025-04-27 00:03:58,698 - latice.data_module - INFO - Loaded diffraction pattern data from data/N=100_noised.npy
2025-04-27 00:03:59,370 - latice.data_module - INFO - Dataset initialized with 333227 samples
2025-04-27 00:03:59,371 - latice.data_module - INFO - Test dataset prepared with 333227 samples
2025-04-27 00:03:59,382 - latice.index.dp_indexer - INFO - Generating latent vectors from patterns in data/N=100_noised.npy


Output()

2025-04-27 00:14:49,312 - latice.index.dp_indexer - INFO - Adding 333227 vectors to database


# Index a single pattern

In [None]:
sample_pattern, angles = indexer._create_dataloader.dataset[0]
sample_pattern = sample_pattern.squeeze()
orientation_result = indexer.index_pattern(
    pattern=sample_pattern, top_n=20, orientation_threshold=3.0
)
print(f"True angles: {angles}")
print(f"Best orientation: {orientation_result.get_top_n_orientations(10)}")
print(f"Success: {orientation_result.success}")


# Batch indexing

In [None]:
batch_patterns = np.load(pattern_path)[:5]
orientation_results = indexer.index_patterns_batch(batch_patterns)
for i, result in enumerate(orientation_results):
    print(f"Pattern {i}: {result.mean_orientation} (success: {result.success})")

In [8]:
import time
from rich.progress import (
    Progress,
    SpinnerColumn,
    TimeElapsedColumn,
    BarColumn,
    TextColumn,
)

batch_patterns_np = np.load(pattern_path)
num_rows = len(batch_patterns_np)
n_samples = 1000
random_indices = np.random.choice(num_rows, size=n_samples, replace=False)
sampled_arr = batch_patterns_np[random_indices]
all_index_times = []

with Progress(
    SpinnerColumn(),
    TextColumn("[progress.description]{task.description}"),
    BarColumn(),
    TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
    TimeElapsedColumn(),
) as progress:
    task = progress.add_task("[cyan]Indexing patterns...", total=n_samples)
    for i in random_indices:
        start_time = time.time()
        orientation_results_batch = indexer.index_pattern(
            batch_patterns_np[i], top_n=1, orientation_threshold=1.0
        )
        end_time = time.time()
        index_time = end_time - start_time
        all_index_times.append(index_time)
        progress.update(task, advance=1, time_elapsed=index_time)

print(
    f"Indexed {len(batch_patterns_np)} patterns in {end_time - start_time:.4f} seconds"
)
mean_index_time = np.mean(all_index_times)
std_index_time = np.std(all_index_times)
print(f"Average time per pattern: {mean_index_time:.4f} seconds")
print(f"Standard deviation of time per pattern: {std_index_time:.4f} seconds")


Output()

Indexed 333227 patterns in 0.0125 seconds
Average time per pattern: 0.0166 seconds
Standard deviation of time per pattern: 0.0874 seconds
