In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import torch

from hepattn.models.posenc import (
    FourierPositionEncoder,
    PositionEncoder,
    pos_enc,
    pos_enc_symmetric,
)
from hepattn.utils.visualise_pes import (
    create_simple_pos_enc_visualization,
    create_similarity_matrix_visualization,
)

In [None]:
def _uniform_phi(num_points: int) -> torch.Tensor:
    """Return ``num_points`` evenly spaced phi values covering [-pi, pi)."""
    return 2 * torch.pi * (torch.arange(num_points) / num_points - 0.5)


def _make_phi_encoder(input_name: str, dim: int, alpha, base):
    """Build a ``PositionEncoder`` that encodes the (symmetric) ``phi`` field of ``input_name``."""
    return PositionEncoder(
        input_name=input_name,
        fields=["phi"],
        sym_fields=["phi"],
        dim=dim,
        alpha=alpha,
        base=base,
    )


def test_create_pos_enc_visualizations_basic(
    *,
    num_hits: int = 1000,
    num_queries: int = 1000,
    dim: int = 128,
    alphas=(1, 2, 20, 100, 1000),
    bases=(100, 50000, 100000),
    out_dir="tests/outputs/posenc",
) -> None:
    """Test the comprehensive positional-encoding visualization functions with basic data.

    For every ``(alpha, base)`` pair, this builds hit and query ``PositionEncoder``s over a
    uniform phi grid. It saves a heat-map of each encoding plus their similarity matrix as
    JPEGs in ``out_dir``.

    All arguments are keyword-only and default to the original hard-coded sweep, so a plain
    call (or pytest collection, which ignores defaulted arguments) behaves as before.

    Args:
        num_hits: Number of hit phi samples.
        num_queries: Number of query phi samples.
        dim: Encoding dimension passed to each ``PositionEncoder``.
        alphas: Values of ``alpha`` to sweep.
        bases: Values of ``base`` to sweep.
        out_dir: Directory for the output images; created if missing.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)

    # The inputs do not depend on (alpha, base), so they are built once, outside the sweep.
    hit_phi = _uniform_phi(num_hits)
    query_phi = _uniform_phi(num_queries)

    for alpha in alphas:
        for base in bases:
            suffix = f"alpha{alpha}_base{base}"

            hit_posencs = _make_phi_encoder("test_hit_input", dim, alpha, base)({"test_hit_input_phi": hit_phi})
            query_posencs = _make_phi_encoder("test_query_input", dim, alpha, base)({"test_query_input_phi": query_phi})

            # The helpers are passed plain strings, as in the original, in case they do not accept Path.
            create_simple_pos_enc_visualization(hit_posencs, save_path=str(out_dir / f"hit_pe_{suffix}.jpeg"))
            create_simple_pos_enc_visualization(query_posencs, save_path=str(out_dir / f"query_pe_{suffix}.jpeg"))
            create_similarity_matrix_visualization(
                hit_posencs,
                query_posencs,
                "Hit PE - Query PE Similarity",
                str(out_dir / f"dot_prod_{suffix}.jpeg"),
            )

            # NOTE(review): it is not visible here whether the helpers close their figures.
            # Closing all figures is a no-op if they already do. Otherwise it stops figures
            # piling up across the sweep (matplotlib warns after 20 open figures).
            plt.close("all")