In [None]:
%pylab inline
import os, sys

sys.path.append("../..")
import numpy as np
import h5py
from tqdm import tqdm

from astroclip.env import format_with_env
from plotting import plot_similar_images, plot_similar_spectra

%pylab inline
import pyarrow
pyarrow.PyExtensionType.set_auto_load(True)
from tqdm import tqdm
import networkx as nx
from pyvis.network import Network
from IPython.display import HTML
import base64
from io import BytesIO
import matplotlib.pyplot as plt
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import normalize
import torch
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact
from IPython.display import display

ASTROCLIP_ROOT = format_with_env("{ASTROCLIP_ROOT}")

# Load the embeddings. The location can be overridden with the
# ASTROCLIP_EMBEDDINGS environment variable instead of editing this cell;
# when it is unset, the original hard-coded path is used.
embedding_loc = os.environ.get(
    "ASTROCLIP_EMBEDDINGS", "/workspace/astroclip/data/astroclip_embeddings.hdf5"
)
with h5py.File(embedding_loc, "r") as f:
    images = f["image"][:]
    spectra = f["spectrum"][:]
    im_embeddings = f["image_embeddings"][:]
    sp_embeddings = f["spectrum_embeddings"][:]
    obj_ids = f["object_id"][:]

# The retrieval cells index images, spectra and both embedding arrays with the
# same row index, so all four must have one row per galaxy, in the same order.
assert len(images) == len(spectra) == len(im_embeddings) == len(sp_embeddings), (
    "images, spectra and embedding arrays have different lengths"
)

# L2-normalise each embedding row so that the dot product of two rows equals
# their cosine similarity.
image_features_normed = im_embeddings / np.linalg.norm(
    im_embeddings, axis=-1, keepdims=True
)
spectrum_features_normed = sp_embeddings / np.linalg.norm(
    sp_embeddings, axis=-1, keepdims=True
)

In [None]:
# Look at a contiguous block of galaxies: a fixed slice of the dataset (not a
# random sample), each panel titled with its dataset index.
grid_size = 15  # panels per row and per column
grid_start = 1000  # dataset index shown in the top-left panel
# squeeze=False keeps `axes` a 2-D array even if grid_size is set to 1.
fig, axes = plt.subplots(grid_size, grid_size, figsize=[15, 15], squeeze=False)
for panel, ax in enumerate(axes.flat):  # row-major, same order as the panel layout
    idx = grid_start + panel
    ax.imshow(images[idx].T)
    ax.set_title(idx)
    ax.axis("off")
fig.subplots_adjust(wspace=0.1, hspace=0.11)

# Plot retrieved galaxy images

In [None]:
# Query galaxies for the image-retrieval figures.
# NOTE: the spectrum-retrieval cell later rebinds `ind_query`; re-run this cell
# before re-plotting images so the queries and `im_sims` stay aligned.
ind_query = [7, 354, 526, 300]
n_retrievals = 8  # keep in sync with `num_retrievals` in the plotting calls

# For every query galaxy, rank all galaxies under four similarity measures and
# keep the *images* of the top `n_retrievals` matches. Embeddings are
# L2-normalised, so each dot product is a cosine similarity; the same-modality
# rankings (sp_sim, im_sim) will normally return the query galaxy itself first.
im_sims = []
for ind in ind_query:
    sp_sim = spectrum_features_normed[ind] @ spectrum_features_normed.T  # spectrum query vs. spectra
    im_sim = image_features_normed[ind] @ image_features_normed.T  # image query vs. images
    x_im_sim = image_features_normed[ind] @ spectrum_features_normed.T  # image query vs. spectra
    x_sp_sim = spectrum_features_normed[ind] @ image_features_normed.T  # spectrum query vs. images

    similarities = {
        "sp_sim": sp_sim,
        "im_sim": im_sim,
        "x_im_sim": x_im_sim,
        "x_sp_sim": x_sp_sim,
    }
    # Descending similarity order, truncated to the top n_retrievals.
    im_sims.append(
        {
            name: [images[i] for i in np.argsort(sim)[::-1][:n_retrievals]]
            for name, sim in similarities.items()
        }
    )

In [None]:
# Image query ranked against image embeddings (im_sim), shown as images
plot_similar_images([images[idx] for idx in ind_query], im_sims,
                    similarity_type="im_sim", num_retrievals=8,
                    save_dir="../outputs/image_retrieval/")

In [None]:
# Spectrum query ranked against spectrum embeddings (sp_sim), shown as images
plot_similar_images([images[idx] for idx in ind_query], im_sims,
                    similarity_type="sp_sim", num_retrievals=8,
                    save_dir="../outputs/image_retrieval/")

In [None]:
# Image query ranked against spectrum embeddings (x_im_sim), shown as images
plot_similar_images([images[idx] for idx in ind_query], im_sims,
                    similarity_type="x_im_sim", num_retrievals=8,
                    save_dir="../outputs/image_retrieval/")

In [None]:
# Spectrum query ranked against image embeddings (x_sp_sim), shown as images
plot_similar_images([images[idx] for idx in ind_query], im_sims,
                    similarity_type="x_sp_sim", num_retrievals=8,
                    save_dir="../outputs/image_retrieval/")

# Plot retrieved galaxy spectra

In [None]:
# Query galaxies for the spectrum-retrieval figures.
# NOTE: this rebinds `ind_query`, which the image-retrieval plot cells also
# read; re-run the image-retrieval cell before re-plotting images.
ind_query = [7, 77]
n_retrievals = 8  # number of spectra kept per similarity type

# For every query galaxy, rank all galaxies under four similarity measures and
# keep the *spectra* of the top `n_retrievals` matches. Embeddings are
# L2-normalised, so each dot product is a cosine similarity; the same-modality
# rankings (sp_sim, im_sim) will normally return the query galaxy itself first.
sp_sims = []
for ind in ind_query:
    sp_sim = spectrum_features_normed[ind] @ spectrum_features_normed.T  # spectrum query vs. spectra
    im_sim = image_features_normed[ind] @ image_features_normed.T  # image query vs. images
    x_im_sim = image_features_normed[ind] @ spectrum_features_normed.T  # image query vs. spectra
    x_sp_sim = spectrum_features_normed[ind] @ image_features_normed.T  # spectrum query vs. images

    similarities = {
        "sp_sim": sp_sim,
        "im_sim": im_sim,
        "x_im_sim": x_im_sim,
        "x_sp_sim": x_sp_sim,
    }
    # Descending similarity order, truncated to the top n_retrievals.
    sp_sims.append(
        {
            name: [spectra[i] for i in np.argsort(sim)[::-1][:n_retrievals]]
            for name, sim in similarities.items()
        }
    )

In [None]:
# Image query ranked against image embeddings (im_sim), shown as spectra.
# NOTE(review): the image-retrieval cells save under "../outputs/" while these
# use "./outputs/" — confirm which base directory is intended.
plot_similar_spectra([spectra[idx] for idx in ind_query],
                     [images[idx] for idx in ind_query],
                     sp_sims,
                     save_dir="./outputs/spectrum_retrieval/",
                     similarity_type="im_sim")

In [None]:
# Spectrum query ranked against spectrum embeddings (sp_sim), shown as spectra.
# NOTE(review): the image-retrieval cells save under "../outputs/" while these
# use "./outputs/" — confirm which base directory is intended.
plot_similar_spectra([spectra[idx] for idx in ind_query],
                     [images[idx] for idx in ind_query],
                     sp_sims,
                     save_dir="./outputs/spectrum_retrieval/",
                     similarity_type="sp_sim")

In [None]:
# Image query ranked against spectrum embeddings (x_im_sim), shown as spectra.
# NOTE(review): the image-retrieval cells save under "../outputs/" while these
# use "./outputs/" — confirm which base directory is intended.
plot_similar_spectra([spectra[idx] for idx in ind_query],
                     [images[idx] for idx in ind_query],
                     sp_sims,
                     save_dir="./outputs/spectrum_retrieval/",
                     similarity_type="x_im_sim")

In [None]:
# Spectrum query ranked against image embeddings (x_sp_sim), shown as spectra.
# NOTE(review): the image-retrieval cells save under "../outputs/" while these
# use "./outputs/" — confirm which base directory is intended.
plot_similar_spectra([spectra[idx] for idx in ind_query],
                     [images[idx] for idx in ind_query],
                     sp_sims,
                     save_dir="./outputs/spectrum_retrieval/",
                     similarity_type="x_sp_sim")

In [None]:
def graph_from_embeddings(embeddings, images, k=5, symmetrize=True):
    """
    Build an undirected k-nearest-neighbour graph from embeddings.

    Neighbours are found with scikit-learn's ``kneighbors_graph`` using the
    cosine metric, so the embeddings need not be pre-normalised.

    Parameters
    ----------
    embeddings : torch.Tensor or array-like, shape (n_nodes, n_features)
        One embedding per node. Torch tensors are detached and moved to the
        CPU; NumPy arrays (such as the features loaded from HDF5) are used as-is.
    images : sequence of torch.Tensor / np.ndarray, or None
        Optional per-node images, indexed like ``embeddings``. Each image's
        first axis is moved to the end — (C, H, W) -> (H, W, C), assuming
        channel-first input (TODO confirm for NumPy images) — and the result is
        stored as the ``image_tensor`` node attribute. If None, every node's
        ``image_tensor`` is None.
    k : int, default 5
        Number of neighbours per node (the node itself is excluded).
    symmetrize : bool, default True
        Symmetrise the adjacency matrix (an edge exists if either endpoint
        lists the other). Because the result is an undirected ``nx.Graph`` with
        connectivity weights, the output graph is the same either way.

    Returns
    -------
    networkx.Graph
        Nodes ``0 .. n_nodes-1``. Each edge has a ``weight`` attribute, which is
        1.0 because the adjacency is built in ``connectivity`` mode.
    """
    emb_np = _embeddings_to_numpy(embeddings)

    adjacency = kneighbors_graph(
        emb_np, n_neighbors=k, metric="cosine", mode="connectivity", include_self=False
    )
    if symmetrize:
        adjacency = adjacency.maximum(adjacency.T)
    adjacency = adjacency.tocoo()

    G = nx.Graph()
    for i in range(adjacency.shape[0]):
        image = None if images is None else _image_to_channel_last(images[i])
        G.add_node(i, image_tensor=image)

    # Cast to built-in types so node ids / weights are plain Python values
    # rather than NumPy scalars from the sparse matrix.
    G.add_edges_from(
        (int(i), int(j), {"weight": float(w)})
        for i, j, w in zip(adjacency.row, adjacency.col, adjacency.data)
    )
    return G


def _embeddings_to_numpy(embeddings):
    """Return embeddings as a NumPy array; torch tensors are detached and moved to CPU."""
    if isinstance(embeddings, torch.Tensor):
        return embeddings.detach().cpu().numpy()
    return np.asarray(embeddings)


def _image_to_channel_last(image):
    """Move the first axis of a 3-D image to the end, e.g. (C, H, W) -> (H, W, C), as NumPy."""
    if isinstance(image, torch.Tensor):
        return image.detach().permute(1, 2, 0).cpu().numpy()
    return np.transpose(np.asarray(image), (1, 2, 0))

def encode_image_base64(image_array, scale=1.):
    """
    Render an image with matplotlib and return it as a PNG data URI.

    Parameters
    ----------
    image_array : array-like
        Anything ``Axes.imshow`` accepts.
    scale : float, default 1.0
        Figure width and height in inches. At the fixed 100 dpi this is
        ``100 * scale`` pixels square, before ``bbox_inches="tight"`` cropping.

    Returns
    -------
    str
        A ``"data:image/png;base64,..."`` string usable as an HTML image source.
    """
    fig, ax = plt.subplots(figsize=(scale, scale), dpi=100)
    try:
        ax.axis("off")
        ax.imshow(image_array)
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Close the figure even if imshow/savefig raises, so repeated calls
        # do not accumulate open figures.
        plt.close(fig)
    img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{img_base64}"

def draw_interactive_graph_colab(graph):
    """
    Render a graph as an interactive pyvis network and return it as IPython HTML.

    Nodes whose ``image_tensor`` attribute is set are drawn as that image,
    embedded as a base64 PNG. Nodes where it is missing or None — e.g. graphs
    built by ``graph_from_embeddings`` with ``images=None`` — fall back to a
    plain dot instead of failing. Edge thickness is ``5 * weight``; a missing
    ``weight`` counts as 1.0.

    Parameters
    ----------
    graph : networkx.Graph
        Graph whose node ids are convertible to ``int``.

    Returns
    -------
    IPython.display.HTML
        The generated page; it renders when returned as a cell's last expression.
    """
    net = Network(height='750px', width='100%', bgcolor='#000000', font_color='white',
                  notebook=False, cdn_resources='remote')

    for node_id, data in graph.nodes(data=True):
        image = data.get('image_tensor')
        if image is not None:
            node_style = {"shape": "image", "image": encode_image_base64(image)}
        else:
            node_style = {"shape": "dot"}
        net.add_node(int(node_id), title=f"Galaxy {node_id}", size=50, **node_style)

    for src, tgt, data in graph.edges(data=True):
        net.add_edge(int(src), int(tgt), value=float(data.get('weight', 1.0)) * 5,
                     color='rgba(255, 255, 255, 0.5)')

    # generate_html returns the page as a string, so nothing is written to disk.
    return HTML(net.generate_html())
