# An AI-Powered Light Curve Similarity Search

## Learning Goals

In this notebook, we will build and train a Convolutional Neural Network (CNN) classifier on transforms of variable star light curves from [TESS-SPOC](https://archive.stsci.edu/hlsp/tess-spoc), then extract embedding vectors to build a light curve similarity search database. 

## Introduction

Conventional light curve data searches are conducted based on metadata---stellar parameters like effective temperature or radius, light curve measurements like amplitude or photometric precision, etc. However, interesting classes of variables don't always fit nicely into these metadata boxes, making class members difficult to query without expensive light curve analyses. For example, eclipsing binary light curves can come from any type of star system, regardless of their stellar parameters. Queries for these types of targets must be done on the features of the light curves themselves.

The enormous wealth of time series photometry from space missions like Kepler, K2, TESS, and soon Roman presents a major challenge to astrophysics investigations that must analyze light curves individually to identify science samples. Analyzing the ~24 million light curves from the first two years of TESS alone, without high performance or multicore computing, would take over nine months at a single second per light curve. Faster, more intelligent algorithms are necessary for bulk analysis in a reasonable timeframe. 
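Here is the arithmetic behind the nine-month estimate. It assumes exactly 24 million light curves at one second each, and an average month of 365.25/12 days:

```python
n_light_curves = 24_000_000
seconds_per_light_curve = 1.0

total_days = n_light_curves * seconds_per_light_curve / 86_400  # seconds per day
total_months = total_days / (365.25 / 12)  # average month length in days
print(f"{total_days:.0f} days, about {total_months:.1f} months")  # 278 days, about 9.1 months
```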

This notebook presents the end-to-end development of a light curve similarity search database powered by machine learning. In such a fast-searching vector database, users can choose a target or upload a light curve and quickly retrieve a desired number of “similar” targets. Similarity is measured between "feature embeddings" (also called "embedding vectors"), which are the values of an intermediate layer extracted from a trained neural network. For our light curves, we will use the MAST High Level Science Product [TESS-SPOC](https://archive.stsci.edu/hlsp/tess-spoc), which consists of full-frame image light curves from the Transiting Exoplanet Survey Satellite (TESS) processed by the Science Processing Operations Center (SPOC) pipeline. In its search for transiting exoplanets, TESS measures light curves for millions of stars, which are also useful for identifying stellar astrophysical phenomena like rotation, flares, and binary eclipses.

In the similarity search database, each light curve has an associated embedding vector, which consists of a fixed number (in this case 64) of floating-point values. This makes the database small compared to a database of full light curves. To find light curves similar to a desired target, the target's embedding vector is compared to every vector in the database using the Euclidean distance, which is also a fast and lightweight computation.
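Below is a minimal NumPy sketch of the brute-force search described above. Random vectors stand in for real embeddings, and both the database size and the `k_nearest` helper are hypothetical. `np.argpartition` finds the `k` smallest distances without fully sorting all of them; only those `k` are then sorted. The sketch also shows the storage cost: 64 float32 values take 256 bytes per light curve.

```python
import numpy as np


def k_nearest(database, query, k=10):
    """Return indices and Euclidean distances of the k rows of `database` closest to `query`."""
    dists = np.linalg.norm(database - query, axis=1)  # one distance per database row
    idx = np.argpartition(dists, k)[:k]               # the k smallest, in arbitrary order
    idx = idx[np.argsort(dists[idx])]                 # order those k by distance
    return idx, dists[idx]


rng = np.random.default_rng(42)
fake_embeddings = rng.normal(size=(100_000, 64)).astype(np.float32)  # made-up database
query = fake_embeddings[123] + 0.01  # a vector very close to row 123

idx, d = k_nearest(fake_embeddings, query, k=10)
print(idx[0])  # 123 should be the closest row
print(fake_embeddings.nbytes / len(fake_embeddings), "bytes per embedding")  # 256.0
```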

To build the similarity search, we must do the following:
1. Choose a training set
2. Preprocess the training data
3. Build a CNN classifier
4. Train the CNN
5. Evaluate the CNN performance
6. Extract the embedding vectors and demonstrate a similarity search

## Dependencies

This notebook uses the following packages:
- `numpy` for general numeric operations
- `pandas` to contain data in DataFrames
- `scipy` for its Fast Fourier Transform library
- `matplotlib.pyplot` for plotting
- `lightkurve` for interacting with light curve data
- `dask` for parallelization
- `sklearn` to build a confusion matrix
- `seaborn` to plot a heatmap
- `torch` for building, training, and evaluating the CNN

If you do not have these packages installed, you can install them using `pip` or `conda`.

## About this Notebook

This notebook was written by Zach Claytor, an Astronomical Data Scientist at STScI.

Contact: zclaytor@stsci.edu

In [None]:
import os
from copy import deepcopy

import numpy as np
import pandas as pd
from scipy import fft
import matplotlib.pyplot as plt
import lightkurve as lk
import dask
from sklearn.metrics import confusion_matrix
import seaborn

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

## 0. Set configuration parameters

The cell below contains parameters that you can set to configure the notebook run. First and foremost, the notebook contains two large computations: processing light curves into wavelet power spectra, and training the CNN. The results of both steps are precomputed, so the steps are not re-run by default. If you want to change any configurations and re-run them, set `RECOMPUTE_WAVELETS` or `RETRAIN_CNN` (or both) to `True`.

Other parameters include labels for the CNN run, which you can change to avoid overwriting existing save files, as well as CNN training configurations: batch size, early stopping patience, and the maximum number of epochs to train for.

The CNN has a configurable number of convolution filters in each layer. In principle, more filters let the network extract higher-order features, but they can also lead to overfitting. The pre-defined configurations are listed in the `channels` dict in the setup cell after the configuration cell, and `RUN_NUMBER` selects which one to use. We use the smallest configuration (`RUN_NUMBER = 0`) by default.

`DEVICE` is the device on which to train the CNN. By default, it's the CPU, but if you have access to a CUDA-enabled GPU, you could swap "cpu" for "cuda".

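`LOSS_FUNCTION` is cross-entropy by default, but you can also experiment with other loss functions. See [the PyTorch documentation](https://docs.pytorch.org/docs/stable/nn.html#loss-functions) for the possibilities. Note that `torch.nn.CrossEntropyLoss` expects unnormalized scores (logits) and applies a log-softmax internally, whereas the `ConvNet` defined in Section 3 applies a softmax to its output before the loss is computed.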

In [None]:
# Edit this cell to modify the CNN run configuration

# Choose whether to re-run computationally expensive steps
RECOMPUTE_WAVELETS = False
RETRAIN_CNN = False

# Choose the run name and CNN configuration to use
RUN_NAME = "cnn"  # Label for this run
RUN_NUMBER = 0  # Which CNN channels configuration to use

# IF RETRAINING
BATCH_SIZE = 50  # Number of training examples to use per batch
PATIENCE = 30  # Number of epochs to wait for loss to decrease before stopping
MAX_EPOCHS = 500  # Maximum number of epochs to train before stopping
TRAINING_COLS = ["ebs", "exo", "flares", "rot"]  # Which columns to use for training
DEVICE = torch.device("cpu")  # Which device to train the CNN on
LOSS_FUNCTION = torch.nn.CrossEntropyLoss()  # Loss function to use
TRAIN_SPLIT = [0.8, 0.1, 0.1]  # training/validation/test split

# `model_name` becomes the base of the output file names.
# It is {RUN_NAME}_{RUN_NUMBER} by default, but you can give it a custom label if desired.
model_name = f"{RUN_NAME}_{RUN_NUMBER}" 

In [None]:
# This cell uses the configurations above to do some basic setup. You can leave it alone.

# Channel configurations
channels = {
    0: [8, 16, 32],
    1: [16, 32, 64],
    2: [32, 64, 128],
    3: [64, 128, 256]
}

selected_channels = channels[RUN_NUMBER]

# Normalize training split to add to 1
TRAIN_SPLIT = np.array(TRAIN_SPLIT)/sum(TRAIN_SPLIT)

# Create output directory if it doesn't exist
os.makedirs(f"runs/{model_name}", exist_ok=True)

## 1. Choose Training Set

For our training set, we will use TESS-SPOC light curves of four types of objects, taken from the references below:
- Eclipsing Binary Stars ("ebs"), [Prša et al. (2022)](https://ui.adsabs.harvard.edu/abs/2022ApJS..258...16P/abstract)
- Exoplanet Transit Hosts ("exo"), [Exoplanet Follow-up Observing Program](https://exofop.ipac.caltech.edu/tess), queried 2025-05-05
- Flaring Stars ("flares"), [Günther et al. (2020)](https://ui.adsabs.harvard.edu/abs/2020AJ....159...60G/abstract)
- Rotating Stars ("rot"), [Kounkel et al. (2022)](https://ui.adsabs.harvard.edu/abs/2022AJ....164..137K/abstract)
<!-- - Asteroseismic Oscillators ("seis") -->

The data tables listing the targets and sectors in which each event was detected are nonuniform, and querying them is beyond the scope of this notebook. Instead, we have already queried MAST for the light curves associated with the targets and sectors listed in these references. The files in the `data` folder contain the cloud URIs of these light curves, so we can read them directly from the cloud without downloading them.

In [None]:
def read_uris(dataset, n_rows=1000):
    # reads URIs from file, given dataset name
    filepath = f"data/tess-spoc-{dataset}-uris.txt"
    return np.loadtxt(filepath, dtype=str, max_rows=n_rows)


def extract_meta(uri):
    # converts URI into TIC ID, Sector pair
    tic_id, sector = uri.split("_")[4].split("-")
    tic_id = int(tic_id)
    sector = int(sector[1:])
    return tic_id, sector


# Read the light curve URIs from file
uris = {x: read_uris(x, n_rows=1000) for x in TRAINING_COLS}

# Flatten URI dict and convert it to DataFrame.
# We'll use this later.
flat_uris = np.concatenate(list(uris.values()))
uri_df = pd.Series({extract_meta(uri): uri for uri in flat_uris})
uri_df = uri_df.to_frame(name="uri").drop_duplicates()
uri_df.index.names = ["TIC", "sector"]
uri_df

## 2. Preprocess Training Data

Rather than using the raw light curves, we will leverage a 2D representation of the data that includes frequency information: the Morlet wavelet transform. This representation comes with several advantages:

- Frequency transforms naturally concentrate frequency information spatially, which helps CNNs to interpret them.
- The wavelet transform is a 2D representation, which lets us leverage image recognition algorithms.
- Wavelet transforms can be binned down to small sizes (e.g., 64x64 pixels) without losing the most useful information.

According to [Torrence and Compo (1998)](https://ui.adsabs.harvard.edu/abs/1998BAMS...79...61T/abstract), "The continuous wavelet transform of a discrete sequence $x_n$ is defined as the convolution of $x_n$ with a scaled and translated version of [the mother wavelet] $\psi_0 (\eta)$":

$W_n(s) = \sum_{n^\prime = 0}^{N-1} x_{n^\prime} \psi^*\left[\frac{(n^\prime - n)\delta t}{s}\right]$,

where $s$ is the wavelet scale, $\delta t$ is the sampling interval, and $^*$ denotes the complex conjugate. The Morlet wavelet is defined as

$\psi_0 (\eta) = \pi^{-1/4} \exp{(i \omega_0 \eta - \eta^2 /2)}$,

where $\eta$ is a nondimensional time parameter and $\omega_0$ is the nondimensional frequency (`w0` in the code below).
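For reference, Torrence and Compo (1998) also give three related results: the Fourier transform of the Morlet wavelet, its normalization at scale $s$, and the relation between scale $s$ and the equivalent Fourier period $\lambda$. The `morlet_wavelet_ft` function and the `scales` and `psi_ft` lines of `cwt_morlet` below appear to implement these. Here $\omega$ is angular frequency and $f = 1/\lambda$ is frequency:

$\hat{\psi}_0(s\omega) = \pi^{-1/4} H(\omega)\, e^{-(s\omega - \omega_0)^2/2}$,

$\hat{\psi}(s\omega_k) = \left(\frac{2\pi s}{\delta t}\right)^{1/2} \hat{\psi}_0(s\omega_k)$,

$\lambda = \frac{4\pi s}{\omega_0 + \sqrt{2 + \omega_0^2}} \quad \Longleftrightarrow \quad s = \frac{\omega_0 + \sqrt{2 + \omega_0^2}}{4\pi f}$,

where $H(\omega)$ is the Heaviside step function. The code omits $H(\omega)$. With $\omega_0 = 6$, the Gaussian factor is already negligible for $\omega < 0$ (about $e^{-18}$ at $\omega = 0$). For $\omega_0 = 6$, $\lambda \approx 1.03\,s$, so the wavelet scale and the Fourier period are nearly equal.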

By the convolution theorem, a convolution in the time domain is equivalent to a pointwise product in the frequency domain, so we can compute the same quantity using the discrete Fourier transform. Here we compute the continuous wavelet transform as the inverse fast Fourier transform (IFFT) of the product of the FFT of the signal and the complex conjugate of the wavelet's Fourier transform.
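The toy check below illustrates that equivalence using made-up arrays; it is not part of the pipeline. The direct sum has the same form as $W_n(s)$ at a single scale. A short Gaussian-windowed complex exponential stands in for the sampled wavelet, and indices wrap around periodically. Multiplying the FFT of the signal by the complex conjugate of the kernel's FFT, then inverting, gives the same result. Without padding, the FFT route computes exactly this wrap-around (circular) version of the sum. Zero-padding, as `cwt_morlet` does by default, reduces but does not remove these edge effects.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 64
x = rng.normal(size=n_samples)  # made-up evenly sampled signal

# Toy kernel at one fixed scale: kernel[j] is the value at lag j, with lags
# ordered 0, 1, ..., 31, -32, ..., -1 (the usual FFT wrap-around ordering)
lags = (np.arange(n_samples) + n_samples // 2) % n_samples - n_samples // 2
kernel = np.exp(1j * 0.8 * lags - 0.5 * (lags / 4.0) ** 2)

# Direct sum with periodic indexing: W[n] = sum_m x[m] * conj(kernel[(m - n) mod N])
m = np.arange(n_samples)
direct = np.array([np.sum(x * np.conj(kernel[(m - n) % n_samples]))
                   for n in range(n_samples)])

# FFT route: IFFT( FFT(x) * conj(FFT(kernel)) )
via_fft = np.fft.ifft(np.fft.fft(x) * np.conj(np.fft.fft(kernel)))

print(np.allclose(direct, via_fft))  # True
```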

For more information on wavelet methods and implementations, see [Torrence and Compo (1998)](https://ui.adsabs.harvard.edu/abs/1998BAMS...79...61T/abstract) and the [`pycwt` documentation](https://pycwt.readthedocs.io).

### 2a. Define processing functions and show an example

In [None]:
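def morlet_wavelet_ft(w, w0=6):
    """Fourier transform of the Morlet wavelet (Torrence & Compo 1998),
    evaluated at w = scale * angular frequency. The Heaviside step is omitted.
    """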
    return np.pi**(-0.25) * np.exp(-0.5 * (w - w0)**2)
    
    
def cwt_morlet(signal, time, freqs, w0=6, pad=True):
    """Morlet wavelet transform as defined by, e.g., Torrence & Compo 1998.

    Arguments:
      signal (numpy array): the signal to be transformed
      time (numpy array): the time array (assumed to be evenly sampled)
      freqs (numpy array): the frequencies at which to evaluate the transform
      w0 (float, default=6): the nondimensional frequency (tweaking this
          affects the resolution at small versus large frequency scales.)
      pad (bool, default=True): whether to zero pad the signal to the next
          power of two length.

    Returns:
      cwt_scaled (complex128 numpy array): the continuous wavelet transform,
          with shape (len(freqs), len(time)), divided by the square root
          of the wavelet scales
    """
    sig_norm = signal - np.mean(signal)  # remove the mean
    # Wavelet scales whose equivalent Fourier periods are 1/freqs (Torrence & Compo 1998)
    scales = (w0 + np.sqrt(2 + w0**2)) / (4*np.pi*freqs)
    N = N_orig = len(time)
    dt = time[1] - time[0]  # sampling interval; assumes evenly spaced times
    
    if pad:
        next_pow_2 = int(2**np.ceil(np.log2(N)))
        sig_norm = np.pad(sig_norm, (0, next_pow_2-N))
        N = next_pow_2

    # Compute signal FFT
    signal_ft = fft.fft(sig_norm, n=N)

    # Fourier angular frequencies
    ftfreqs = 2*np.pi*fft.fftfreq(N, dt)
    # Set up wavelet grid for each scale
    psi_ft = np.sqrt(2*np.pi*scales[:, None]/dt) * morlet_wavelet_ft(scales[:, None] * ftfreqs).conj()
    # Compute IFFT to produce the wavelet transform
    cwtmat = fft.ifft(signal_ft * psi_ft, axis=1)

    cwt_scaled = cwtmat[:, :N_orig]/scales[:, None]**0.5
    return cwt_scaled


def reshape_power(power, output_size=64):
    """Bin a 2D array down to `output_size` using PyTorch's adaptive average pooling.
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size)
        
    power_tensor = torch.tensor(np.expand_dims(power, 0))
    
    power = torch.nn.functional.adaptive_avg_pool2d(power_tensor, output_size=output_size)
    return np.squeeze(power.numpy())
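When each input dimension is an exact multiple of the output size, `adaptive_avg_pool2d` averages over equal, non-overlapping windows. In that case, `reshape_power` reduces to plain block averaging. The NumPy sketch below shows that operation on a made-up 512 × 1024 array; the `block_mean` helper is hypothetical and for illustration only. When the sizes do not divide evenly, PyTorch's adaptive pooling instead uses slightly uneven windows that can overlap.

```python
import numpy as np


def block_mean(arr, out_shape=(64, 64)):
    """Average non-overlapping blocks; each input dim must be a multiple of the output dim."""
    h, w = arr.shape
    oh, ow = out_shape
    assert h % oh == 0 and w % ow == 0, "input dims must be multiples of output dims"
    # Split each axis into (output pixel, position within block), then average the blocks
    return arr.reshape(oh, h // oh, ow, w // ow).mean(axis=(1, 3))


fake_power = np.random.default_rng(1).random((512, 1024))  # made-up |wavelet| array
binned = block_mean(fake_power)
print(binned.shape)                                      # (64, 64)
print(np.isclose(binned[0, 0], fake_power[:8, :16].mean()))  # True
```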

Now let's see an example. For the first rotating star light curve, we'll look at
1. the light curve
2. the wavelet transform
3. the binned wavelet transform

In [None]:
uri = uris["rot"][0]
lc = lk.read(uri)

# Plot the light curve
lc.plot()

# Remove NaNs and fill gaps
lc = lc.remove_nans().fill_gaps()
time = lc.time.value
flux = lc.flux.value

# Compute and plot the wavelet transform
freq = np.geomspace(1/10, 10, 512)
power = cwt_morlet(flux, lc.time.value, freq)

plt.figure()
im = plt.pcolormesh(time, 1/freq, np.abs(power), shading="nearest")
plt.yscale("log")
plt.gca().invert_yaxis()
plt.xlabel("Time (days)")
plt.ylabel("Period (days)")
plt.colorbar(im, label="Power (arbitrary units)")

# Bin the wavelet transform and plot it
binned_power = reshape_power(np.abs(power))
plt.figure()
plt.imshow(binned_power, origin="lower")
plt.axis("off")

The light curve shows a strong signal that repeats more often than once per day. In the wavelet transform, this signal appears as a horizontal band of power at a period of about 0.7 days, with a slightly weaker band at about 1.4 days. The gap halfway through the time axis corresponds to a gap in the light curve, when the telescope briefly paused observations to downlink data to Earth. After binning, all this information is still preserved, but the image takes up much less memory.

### 2b. Process all light curves

We will loop over all the light curves to generate and save wavelet power spectra. We can accelerate this process using parallelization with `dask`.

In [None]:
# Define function to process one light curve
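def process_lc(uri, class_label, flux_column="PDCSAP_FLUX", quality_bitmask="default"):
    # Note: `flux_column` and `quality_bitmask` are currently not forwarded to
    # `lk.read` (the corresponding lines are commented out), so lightkurve's
    # defaults are used.
    lc = lk.read(
        uri, # the path to the light curve in the cloud
        # flux_column=flux_column, # which column to set as the flux
        # quality_bitmask=quality_bitmask, # remove bad data
    )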
    
    # Remove NaNs and fill gaps
    lc = lc.remove_nans().fill_gaps()
    
    # Compute the wavelet transform
    freq = np.geomspace(1/10, 12, 512)
    power = cwt_morlet(lc.flux.value, lc.time.value, freq)
    
    binned_power = reshape_power(np.abs(power)) # bin the power spectrum
    binned_power -= binned_power.min() # normalize the power spectrum to [0, 255]
    binned_power *= 255/binned_power.max()
    binned_power = binned_power.astype(np.uint8) # convert to 8-bit integers
    
    # Save normalized power spectrum
    basename = os.path.basename(uri).replace("lc.fits", "wps.npy")
    np.save(os.path.join("data", class_label, basename), binned_power)

    print(os.path.join(class_label, basename), "done.")
    return binned_power
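`process_lc` stores each binned spectrum as 8-bit integers after min-max scaling, so a 64 × 64 image takes 4,096 bytes of data. The sketch below applies that scaling to a made-up array. Because `astype(np.uint8)` truncates rather than rounds, each stored value is less than one integer step (1/255 of the original range) below its scaled value. The training-set cell later divides each loaded image by its maximum, so the absolute scale is not needed.

```python
import numpy as np

rng = np.random.default_rng(7)
fake_binned = rng.random((64, 64)) * 3.0 + 0.5  # made-up binned amplitudes

scaled = fake_binned - fake_binned.min()  # shift so the minimum is 0
scaled *= 255 / scaled.max()              # stretch so the maximum is ~255
as_uint8 = scaled.astype(np.uint8)        # truncate to 8-bit integers

step = (fake_binned.max() - fake_binned.min()) / 255  # original units per integer step
restored = as_uint8 * step + fake_binned.min()
max_err = np.max(np.abs(restored - fake_binned))

print(as_uint8.nbytes)  # 4096 bytes for a 64 x 64 image
print(max_err / step)   # below 1, up to floating-point rounding
```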

In [None]:
if RECOMPUTE_WAVELETS:
    # Create data subdirectories
    for c in TRAINING_COLS:
        os.makedirs(os.path.join("data", c), exist_ok=True)
    
    # Set up Dask job
    lazy_results = []
    for c in TRAINING_COLS:
        for uri in uris[c]:
            lazy_result = dask.delayed(process_lc)(uri, c)
            lazy_results.append(lazy_result)
    
    # Run Dask job
    wavelets = dask.compute(*lazy_results)

### 2c. Build a container for the training data

PyTorch lets users train in batches using its `Dataset` and `DataLoader` classes. Choosing a suitable batch size can help the optimizer find minima in the loss landscape more quickly, which can accelerate training. Here we will set up a custom dataset class to take advantage of batch training.

In [None]:
class WaveletDataset(Dataset):
    """Custom dataset class for binned wavelet transforms and their class labels.

    This class wraps the wavelet images and corresponding labels stored in a
    DataFrame so that they can be batched with a PyTorch DataLoader. Splitting
    into training, validation, and test sets is done separately with `random_split`.

    Attributes:
        features (np.ndarray): Array containing the light curve wavelet transforms.
        labels (np.ndarray): Array containing the classifications for each light curve.
        columns (list of str): List of names of classification columns.
    """    
    def __init__(self, data, columns):
        """
        Initialize the dataset.
        
        Args:
            data (DataFrame): the DataFrame containing the wavelet and classification data.
                Wavelet data should be in a column labeled 'wavelet', while the classification
                columns must be specified by the `columns` argument.
            columns (list of str): list of names of classification columns to use.
        """ 
        self.features = data["wavelet"].values
        self.labels = data[columns].values
        self.columns = columns

    def __len__(self):
        """Return the length of the dataset."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Retrieve a single sample and its corresponding label.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            tuple: A tuple containing the sample data (torch.Tensor) and the 
                   corresponding label (torch.Tensor).
        """
        X = self.features[idx].astype("float32")
        X = torch.unsqueeze(torch.tensor(X), 0)
        label = torch.tensor(self.labels[idx].astype("float32"))
        return X, label

Now we will put the training data into a DataFrame before passing it to the `WaveletDataset` class.

In [None]:
# Set up dataframe
training_data = pd.DataFrame([], columns=["TIC", "sector", *TRAINING_COLS, "wavelet"])
training_data.index.name = "filename"

# Iterate over wavelet files, adding them to the dataframe
for i, col in enumerate(TRAINING_COLS):
    for s in os.listdir(os.path.join("data", col)):
        if not s.endswith(".npy"):
            continue
        if s not in training_data.index:
            wav = np.load(os.path.join("data", col, s))
            tic_id, sector = s.split("_")[4].split("-") # split target name field
            tic_id = int(tic_id)
            sector = int(sector[1:]) # trim off leading 's'
            training_data.loc[s] = [tic_id, sector, *([0]*len(TRAINING_COLS)), wav/wav.max()]
        training_data.loc[s, col] = 1

training_data = training_data.reset_index().set_index(["TIC", "sector"]).sort_index()

# Normalize the class flags so that each row sums to 1
training_data[TRAINING_COLS] = training_data[TRAINING_COLS].div(training_data[TRAINING_COLS].sum(axis=1), axis=0)

# Filter out rows that have multiple classes
training_data = training_data[training_data[TRAINING_COLS].max(axis=1) == 1]

# Keep the first 1000 rows of each class
training_data = pd.concat([training_data.query(f"{c} == 1").iloc[:1000] for c in TRAINING_COLS]).sort_index()

# join with light curve URIs from earlier
training_data = training_data.join(uri_df)
training_data
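The two label-cleaning steps above have a combined effect. Dividing each row of class flags by its sum, then keeping only rows whose largest value is exactly 1, drops any light curve that appears in more than one reference catalog. Here is a toy NumPy version of the same logic, with three made-up targets:

```python
import numpy as np

# Rows: three made-up targets; columns: ebs, exo, flares, rot.
# A 1 means the target appears in that class's reference catalog.
flags = np.array([
    [1, 0, 0, 0],  # only an eclipsing binary
    [1, 1, 0, 0],  # listed as both an EB and an exoplanet host
    [0, 0, 0, 1],  # only a rotator
], dtype=float)

normalized = flags / flags.sum(axis=1, keepdims=True)  # each row sums to 1
keep = normalized.max(axis=1) == 1                    # True only for single-class rows

print(normalized[1])  # [0.5 0.5 0.  0. ]
print(keep)           # [ True False  True]
```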

In [None]:
# Cast DataFrame to WaveletDataset
wavelet_data = WaveletDataset(training_data, columns=TRAINING_COLS)

# Randomly partition the data between training, validation, and test sets
generator1 = torch.Generator().manual_seed(42)
train_set, validation_set, test_set = random_split(wavelet_data, lengths=TRAIN_SPLIT, generator=generator1)

# Cast datasets to DataLoaders for training in batches
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE)
validation_loader = DataLoader(validation_set, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)
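Here `random_split` receives fractions rather than integer lengths. PyTorch documents how it converts fractions to subset sizes. First it floors `fraction * len(dataset)` for each subset. Then it hands out any leftover samples one at a time, round-robin, starting with the first subset. The plain-Python sketch below mirrors that conversion. The `split_lengths` helper is hypothetical, and the dataset size of 3,999 is made up:

```python
import math


def split_lengths(n, fractions):
    """Turn split fractions into integer subset sizes, mirroring random_split's documented rule."""
    lengths = [int(math.floor(n * f)) for f in fractions]
    remainder = n - sum(lengths)
    for i in range(remainder):  # distribute leftovers one at a time, round-robin
        lengths[i % len(lengths)] += 1
    return lengths


print(split_lengths(3999, [0.8, 0.1, 0.1]))  # [3200, 400, 399]
```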

## 3. Build the CNN Classifier

The CNN has the following components:
- 2D convolution layers to extract features from wavelet images
- A configurable number of convolution filters
- Max-pooling along the time axis only (kernels of height 1) for dimensionality reduction while preserving frequency resolution
- Batch normalization to normalize layer activations and stabilize training
- ReLU activations to introduce nonlinearity
- Dropout to regularize the network by building in redundancy
- Softmax output, so that predictions can be read as class probabilities
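You can check the layer sizes noted in the `ConvNet` comments, and its `n_nodes` formula, by hand. The plain-Python sketch below follows a 64 × 64 (frequency × time) input through three unpadded, stride-1, 3 × 3 convolutions, each of which trims 2 pixels per axis. It then applies the time-only max-pooling steps. It assumes the default `c=[8, 16, 32]`, and the helper names are illustrative.

```python
def conv_out(size, k):
    """Output length of an unpadded, stride-1 convolution with kernel size k."""
    return size - (k - 1)


def pool_out(size, window):
    """Output length of max pooling whose stride equals its window (floor mode)."""
    return size // window


conv_channels, k = [8, 16, 32], 3
h = w = 64  # (frequency, time) size of each binned wavelet image

h, w = conv_out(h, k), pool_out(conv_out(w, k), 3)  # conv1 + (1, 3) pool -> 62 x 20
h, w = conv_out(h, k), pool_out(conv_out(w, k), 3)  # conv2 + (1, 3) pool -> 60 x 6
h, w = conv_out(h, k), pool_out(conv_out(w, k), conv_out(w, k))  # conv3 + full-width pool -> 58 x 1

n_flat = conv_channels[-1] * h * w
print(h, w, n_flat)  # 58 1 1856
```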

In [None]:
class ConvNet(nn.Module):
    def __init__(self, c=[8, 16, 32], k=3, n_output=4):
        if isinstance(k, int):
            k = [k]*len(c)
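        # Each unpadded k x k convolution trims (k - 1) rows from the 64-row frequency
        # axis, and the final max-pool collapses the time axis to width 1.
        n_nodes = (64 - (sum(k) - len(k))) * c[-1]

        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1,  c[0], k[0], 1) # after pooling: 62 x 20 (freq x time)
        self.conv1_bn = nn.BatchNorm2d(c[0])
        self.conv2 = nn.Conv2d(c[0], c[1], k[1], 1) # after pooling: 60 x 6
        self.conv2_bn = nn.BatchNorm2d(c[1])
        self.conv3 = nn.Conv2d(c[1], c[2], k[2], 1) # after pooling: 58 x 1
        self.conv3_bn = nn.BatchNorm2d(c[2])
        
        self.fc1 = nn.Linear(n_nodes, 256) # e.g., 58 x 32 = 1856 inputs for c=[8, 16, 32]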
        self.fc1_bn = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 64)
        self.fc2_bn = nn.BatchNorm1d(64)
        self.fc3 = nn.Linear(64, n_output)
        
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout2d(0.1)

    def forward(self, x, return_embeddings=False):
        x = self.conv1(x)
        x = self.conv1_bn(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (1, 3))
        x = self.dropout2(x)

        x = self.conv2(x)
        x = self.conv2_bn(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (1, 3))
        x = self.dropout2(x)

        x = self.conv3(x)
        x = self.conv3_bn(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (1, x.shape[-1]))
        x = self.dropout2(x)

        x = torch.flatten(x, 1)

        x = self.fc1(x)
        x = self.fc1_bn(x)
        x = F.relu(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.fc2_bn(x)

        if return_embeddings:
            return x
        
        x = F.relu(x)
        x = self.dropout1(x)
        
        output = self.fc3(x)
        output = F.softmax(output, dim=1)
        return output

## 4. Train the CNN

We train the CNN using the `Adam` optimizer, which adapts the step size of each parameter during training. On top of that, we vary the overall learning rate (LR) with a plateau scheduler (`ReduceLROnPlateau`). This scheduler multiplies the LR by 0.7 whenever the validation loss fails to improve for more than 3 consecutive epochs. Smaller steps later in training help the CNN parameters settle into a local minimum rather than taking large steps over it.

Finally, we also use an early stopping criterion. If the validation loss fails to improve for `PATIENCE` consecutive epochs, training stops early, and the weights from the epoch with the lowest validation loss are kept.
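The plain-Python replay below makes the interaction between the plateau scheduler and early stopping concrete. It uses `factor=0.7, patience=3`, as set in `train` below, on a made-up sequence of validation losses. It is a simplified sketch, not PyTorch's implementation: it ignores `ReduceLROnPlateau`'s relative improvement threshold, cooldown, and minimum LR, and the `simulate_schedule` helper is hypothetical.

```python
def simulate_schedule(val_losses, lr=1e-5, factor=0.7, lr_patience=3, stop_patience=30):
    """Simplified replay of a plateau LR rule plus early stopping (illustrative only).

    Returns the best epoch (1-based) and the learning rate in effect after each epoch.
    """
    best_loss, best_epoch = float("inf"), None
    lr_bad = stop_bad = 0  # consecutive epochs without improvement
    lrs = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            lr_bad = stop_bad = 0
        else:
            lr_bad += 1
            stop_bad += 1
            if lr_bad > lr_patience:  # more than `patience` bad epochs: shrink the LR
                lr *= factor
                lr_bad = 0
        lrs.append(lr)
        if stop_bad == stop_patience:  # early stopping, as in `train`
            break
    return best_epoch, lrs


toy_losses = [1.0, 0.9, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95]
best_epoch, lrs = simulate_schedule(toy_losses, lr_patience=3, stop_patience=5)
print(best_epoch, len(lrs))  # 2 7
```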

First we will define our training function.

In [None]:
def train(model, device, train_loader, val_loader, patience, max_epochs, model_name="cnn"):
    """Train the neural network for the specified number of epochs.

    This function orchestrates the training loop, updating model weights based on
    the training data, validating the model on a validation set, and handling
    early stopping based on validation loss.

    Args:
        model (torch.nn.Module): The neural network model to be trained.
        device (torch.device): The device (CPU or GPU) on which to perform training.
        train_loader (DataLoader): DataLoader for the training dataset.
        val_loader (DataLoader): DataLoader for the validation dataset.
        patience (int): Early stopping patience.
        max_epochs (int): Maximum number of training iterations.
        model_name (str): The name of the model, used for saving the trained weights.

    Returns:
        tuple: A tuple containing the best model weights, training losses, and validation losses.
    """
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.7, patience=3)
    
    train_losses, val_losses = [], []
    min_loss = float("inf")
    early_stopping_count = 0
    best_weights = None

    for epoch in range(1, max_epochs + 1):
        train_loss = train_epoch(model, device, train_loader, optimizer, epoch)
        val_loss = test(model, device, val_loader, epoch, model_name, mode="Validation")
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        scheduler.step(val_loss)

        if val_loss < min_loss:
            min_loss = val_loss
            early_stopping_count = 0
            best_weights = deepcopy(model.state_dict())
        else:
            early_stopping_count += 1
            print(f"        Early Stopping count: {early_stopping_count}/{patience}")
            if early_stopping_count == patience:
                print(f"\nEarly Stopping. Best Epoch: {epoch - patience} with loss {min_loss:.4f}.")
                break

    return best_weights, train_losses, val_losses


def train_epoch(model, device, train_loader, optimizer, epoch):
    """Train the model for one epoch.

    This function processes each batch of training data, computes the loss,
    and updates the model weights accordingly.

    Args:
        model (torch.nn.Module): The neural network model to be trained.
        device (torch.device): The device (CPU or GPU) for training.
        train_loader (DataLoader): DataLoader for the training dataset.
        optimizer (torch.optim.Optimizer): The optimizer used for weight updates.
        epoch (int): The current epoch number.

    Returns:
        float: The average loss for the epoch.
    """
    model.train()
    losses = []

    ndata = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = LOSS_FUNCTION(output, target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        ndata += len(data)
        print(
            f"Train Epoch: {epoch:3d} [{ndata:6d}/{len(train_loader.dataset)}"
            f" ({100*ndata/len(train_loader.dataset):3.0f}%)]\tLoss: {losses[-1]:.6f}",
            end="\r")

    return np.mean(losses)


def test(model, device, test_loader, epoch=None, model_name=None, mode="Validation"):
    """Evaluate the model on a dataset.

    This function assesses the model's performance on the given dataset
    and computes the average loss per batch.

    Args:
        model (torch.nn.Module): The neural network model to be evaluated.
        device (torch.device): The device (CPU or GPU) for evaluation.
        test_loader (DataLoader): DataLoader for the dataset to evaluate.
        epoch (int, optional): The current epoch number (currently unused).
        model_name (str, optional): The name of the model (currently unused).
        mode (str, optional): Label for the printout ("Training", "Validation", or "Test").

    Returns:
        float: The average loss. When `mode` is "Test" (case-insensitive),
            a tuple of (predictions, targets, average loss) is returned instead.
    """
    model.eval()
    test_loss = 0
    targets, preds = [], []
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            targets.extend(target.cpu().numpy())
            preds.extend(output.cpu().numpy())
            test_loss += LOSS_FUNCTION(output, target).item()

    test_loss /= len(test_loader)
    print(f"\n                     Average {mode} Loss: {test_loss:.4f}")

    if mode.lower() == "test":
        return np.squeeze(preds), np.squeeze(targets), test_loss
    return test_loss

If `RETRAIN_CNN==True`, we will train the CNN here. Otherwise, we will load the default weights from file.

In [None]:
model = ConvNet(selected_channels, k=3, n_output=len(TRAINING_COLS)).to(DEVICE)

if RETRAIN_CNN:
    # Train the CNN
    weights, train_losses, validation_losses = train(model, DEVICE, train_loader, validation_loader, PATIENCE, MAX_EPOCHS, model_name)
    torch.save(weights, os.path.join("runs", model_name, f"{model_name}.pt"))
    np.save(
        os.path.join("runs", model_name, f"{model_name}_loss.npy"), 
        np.stack([train_losses, validation_losses])
    )
else:
    # Load weights from file (map_location lets GPU-trained weights load on the CPU)
    weights = torch.load(os.path.join("runs", model_name, f"{model_name}.pt"), map_location=DEVICE)
    train_losses, validation_losses = np.load(os.path.join("runs", model_name, f"{model_name}_loss.npy"))
    
# Evaluate best-fit model
model.load_state_dict(weights)
print("\nFinal Performance!")
test(model, DEVICE, train_loader, mode="Training")
test(model, DEVICE, validation_loader, mode="Validation")
test_preds, test_labels, test_loss = test(model, DEVICE, test_loader, mode="Test")

Now that the CNN is trained, we should take a look at the "learning curves," or the loss over time.

In [None]:
# Plot learning curve
plt.figure()
plt.plot(train_losses, label="Training Loss")
plt.plot(validation_losses, label="Validation Loss", alpha=0.5)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

Although `MAX_EPOCHS` is 500 by default, the training function stops early once the validation loss has failed to improve for `PATIENCE` (30 by default) consecutive epochs. It then keeps the weights from the epoch with the lowest validation loss.

## 5. Evaluate CNN Performance

Now that training is complete, it's useful to evaluate the CNN on the held-out test set to see how accurately it classifies data that it hasn't seen before.

In [None]:
# Join test set predictions with the input data
test_df = training_data.iloc[test_set.indices].copy()
test_df.loc[:, [c + "_pred" for c in wavelet_data.columns]] = test_preds
test_df

In [None]:
# Take the prediction for each light curve to be the class with the highest probability
prediction = test_df[[c + "_pred" for c in TRAINING_COLS]].values.argmax(axis=1)
truth = test_df[TRAINING_COLS].values.argmax(axis=1)

# Create a confusion matrix to look at accuracy and false positive/negatives
cm = confusion_matrix(truth, prediction)
cm = cm.astype(float)/cm.sum(axis=1)[:, np.newaxis]

# Visualize confusion matrix
plt.figure(figsize=(8, 6))
seaborn.heatmap(cm, annot=True, fmt=".2f", cmap="Blues", xticklabels=TRAINING_COLS, yticklabels=TRAINING_COLS)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Confusion Matrix")
plt.show()

The example CNN correctly classifies at least 77% of the held-out examples in each class (the diagonal of the row-normalized confusion matrix). The failure modes are interesting:
- True EBs are often misclassified as (a) exoplanet transits or (b) rotators. This is likely because (a) the shape of a binary eclipse resembles that of a planet transit, and (b) EBs are often rapidly rotating due to binary interactions, so they may also exhibit rotation signatures.
- Transiting exoplanet hosts are correctly classified 86% of the time, and their misclassifications are spread roughly evenly across the other three classes.
- Flaring stars are rarely misclassified as rotators, but rotators are more often misclassified as flaring stars. Rapid rotators often also flare, so the latter isn't too surprising, but it's interesting that the converse doesn't occur as often.

All in all, the predictions make physical sense, but definitely require some digging to understand.
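The confusion matrix above is normalized by row (`cm.sum(axis=1)`), so each diagonal entry is the fraction of a true class that was predicted correctly, i.e. that class's recall. The sketch below computes the same row-normalized matrix in NumPy, without scikit-learn, using made-up labels:

```python
import numpy as np

# Made-up integer labels: 0=ebs, 1=exo, 2=flares, 3=rot
toy_truth = np.array([0, 0, 0, 0, 1, 1, 2, 2, 3, 3])
toy_pred = np.array([0, 0, 1, 3, 1, 1, 2, 2, 3, 2])

n_classes = 4
toy_cm = np.zeros((n_classes, n_classes))
np.add.at(toy_cm, (toy_truth, toy_pred), 1)          # rows: true class, columns: predicted
toy_cm_norm = toy_cm / toy_cm.sum(axis=1, keepdims=True)  # each row sums to 1

print(np.diag(toy_cm_norm))  # per-class recall: [0.5 1.  1.  0.5]
```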

## 6. Extract Embeddings and Conduct the Similarity Search

Now that we have verified that the CNN performs well on held-out data, we can extract the feature embeddings and perform the similarity search. We'll first get the embedding vectors of the test set, then pick one light curve (the first in the test set) as our target, and then find the 10 most similar ones.

In [None]:
if RETRAIN_CNN:
    # We only need to extract the embeddings if the CNN was retrained
    test_data = wavelet_data.features[test_set.indices]
    tdata = torch.tensor(np.stack(test_data), dtype=torch.float32).unsqueeze(1).to(DEVICE)
    
    model.eval()
    with torch.no_grad():
        # We built the CNN with a special clause to export the embeddings during forward propagation
        embeddings = model(tdata, return_embeddings=True).cpu().numpy()

    # Save the embeddings to file
    np.save(os.path.join("runs", model_name, f"{model_name}_embeddings.npy"), embeddings)
else:
    # Otherwise we just read the old ones from file
    embeddings = np.load(os.path.join("runs", model_name, f"{model_name}_embeddings.npy"))

# Add the embeddings to the test DataFrame
test_df.loc[:, "embedding"] = list(embeddings)
test_df

In [None]:
# For our "test target", we'll pop the first row of the test DataFrame.
my_row = test_df.iloc[0]
popped_df = test_df.drop(my_row.name)

In [None]:
# This cell contains the actual similarity search. The steps are:
#  1. Stack the embedding vectors of the candidate light curves into one array
#  2. Compute the Euclidean distance between the target and every candidate
#  3. Return the 10 closest targets

points = np.stack(popped_df["embedding"].to_numpy())
dists = np.linalg.norm(points - my_row["embedding"], axis=1)

popped_df["dist"] = dists
closest = popped_df.nsmallest(10, "dist")
closest

### Visualize "similar" light curves

Now that we have performed the similarity search, we should visualize the returned light curves to decide whether or not we believe that they are similar.

In [None]:
# Read light curves
for i, row in closest.iterrows():
    closest.loc[i, "lightcurve"] = lk.read(training_data.loc[i, "uri"])

my_lc = lk.read(training_data.loc[my_row.name, "uri"])

In [None]:
# Plot the "target" light curve
ax = my_lc.plot()
ax.legend().remove()
ax.set_title(f"TIC {my_lc.ticid}, Sector {my_lc.sector}")
plt.show()

The target light curve appears to be an eclipsing binary or an ellipsoidal variable (note that we didn't have a separate class for ellipsoidal variations). But how do the "similar" light curves compare?

In [None]:
# Plot the top 10 most similar light curves
for i, row in closest.iterrows():
    lc = row["lightcurve"]
    ax = lc.plot()
    ax.legend().remove()
    ax.set_title(f"TIC {lc.ticid}, Sector {lc.sector}")
plt.show()

All 10 of the returned light curves appear to show the same kinds of ellipsoidal variations and/or binary eclipses. Remember that the search itself never looks at the class labels. It compares only the extracted embedding vectors, although those embeddings were learned by training the classifier on labeled data.

## Summary

In this notebook we developed a light curve similarity search using a CNN classifier to learn and export feature embeddings from the light curves. We selected and processed training data, built and trained the CNN, extracted the embeddings, and showed an example of a similarity search. Similarity searches are powerful because they do not rely on metadata the way classical database queries do; instead, they rely only on features of the data. This kind of search can empower users to build science samples without needing to know *a priori* what kinds of objects produce the features they see in the data.

## Exercises

In this example, we used a single CNN configuration to build the similarity search, but there are other configurations and choices that may be enlightening. For example,
1. Use different numbers of convolution filters by changing `RUN_NUMBER`, or if you're feeling daring, manually configure the CNN `channels`.
2. Train on fewer training classes by removing a column from `TRAINING_COLS`. You may find that adding or removing classes affects the classification accuracy.
3. Use different "test" examples for the similarity search. Do the returned light curves look similar?